#!/usr/bin/env python3

import argparse
import gzip
import logging
import os
import re
import shutil
import time
from typing import Tuple, Dict, Optional, List

import matplotlib.pyplot as plt
import pandas as pd
import requests
import scipy.stats as stats
import seaborn as sns
import xmltodict
from tqdm import tqdm

# Configure logging: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Constants
# NCBI E-utilities "efetch" endpoint; fetch_metadata queries it with db=biosample.
NCBI_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
# Names of the sub-folders created inside the per-organism output folder
# by create_output_directory.
METADATA_FOLDER_NAME = "metadata_output"
FIGURES_FOLDER_NAME = "figures"
SEQUENCE_FOLDER_NAME = "sequence"

# Cache to store fetched metadata, keyed by BioSample id. Each value is the
# (isolation_source, collection_date, geo_location, host) tuple stored by
# fetch_metadata, so a BioSample shared by several rows is fetched only once.
metadata_cache: Dict[str, Tuple] = {}


def load_data(input_file: str) -> pd.DataFrame:
    """Read a tab-separated file into a DataFrame.

    Args:
        input_file: Path of the TSV file to read.

    Returns:
        The parsed table.

    Raises:
        Exception: Whatever ``pandas.read_csv`` raises; the error is logged
            and then re-raised unchanged.
    """
    try:
        table = pd.read_csv(input_file, sep="\t")
    except Exception as err:
        logging.error(f"Error loading data from {input_file}: {err}")
        raise
    logging.info(f"Data loaded successfully from {input_file}")
    return table


def filter_data(df: pd.DataFrame, checkm_threshold: float) -> pd.DataFrame:
    """Keep rows that pass the CheckM completeness and ANI quality filters.

    A row is kept when its "CheckM completeness" is present and strictly
    greater than ``checkm_threshold`` and its "ANI Check status" equals "OK".

    Args:
        df: Table containing the "CheckM completeness" and
            "ANI Check status" columns.
        checkm_threshold: Exclusive lower bound for CheckM completeness.

    Returns:
        An independent copy of the matching rows (original index preserved),
        so callers can add or modify columns without touching ``df`` and
        without triggering pandas' SettingWithCopyWarning.

    Raises:
        Exception: Any pandas error (e.g. ``KeyError`` for a missing column)
            is logged and re-raised.
    """
    try:
        completeness = df["CheckM completeness"]
        keep = (
            completeness.notna()
            & (completeness > checkm_threshold)
            & (df["ANI Check status"] == "OK")
        )
        # .copy() detaches the result from ``df``; callers add columns to it.
        filtered_df = df[keep].copy()
        logging.info(f"Data filtered with CheckM threshold {checkm_threshold}")
        return filtered_df
    except Exception as e:
        logging.error(f"Error filtering data: {e}")
        raise


def create_output_directory(output_directory: str, organism_name: str) -> Tuple[str, str, str, str]:
    """Create the per-organism output folder and its three sub-folders.

    Spaces in ``organism_name`` are replaced by underscores to build the
    organism folder name. Existing folders are reused.

    Args:
        output_directory: Parent directory for all output.
        organism_name: Organism name used for the top-level folder.

    Returns:
        ``(organism_folder, metadata_folder, figures_folder, sequence_folder)``.

    Raises:
        Exception: Any error raised while creating the folders is logged and
            re-raised.
    """
    try:
        organism_folder = os.path.join(output_directory, organism_name.replace(" ", "_"))
        subfolders = tuple(
            os.path.join(organism_folder, folder_name)
            for folder_name in (METADATA_FOLDER_NAME, FIGURES_FOLDER_NAME, SEQUENCE_FOLDER_NAME)
        )
        for path in subfolders:
            os.makedirs(path, exist_ok=True)

        logging.info(f"Output directories created: {organism_folder}")
        return (organism_folder, *subfolders)
    except Exception as err:
        logging.error(f"Error creating output directories: {err}")
        raise



def fetch_metadata(biosample_id: str, sleep_time: float, timeout: float = 30.0) -> Tuple:
    """Fetch selected BioSample attributes from NCBI E-utilities.

    Successful lookups are stored in the module-level ``metadata_cache`` and
    served from there on later calls. Failed lookups are not cached, so they
    are retried if the same id is requested again.

    Args:
        biosample_id: BioSample identifier passed as the ``id`` query parameter.
        sleep_time: Seconds to pause after each HTTP request (rate limiting).
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the whole run.

    Returns:
        ``(isolation_source, collection_date, geo_location, host)``. Any value
        that is not present in the record is ``pd.NA``; on a network, HTTP or
        parsing error all four values are ``pd.NA``.
    """
    if biosample_id in metadata_cache:
        return metadata_cache[biosample_id]

    missing = (pd.NA, pd.NA, pd.NA, pd.NA)
    # Attribute name in the XML -> position in the returned tuple.
    slots = {"isolation_source": 0, "collection_date": 1, "geo_loc_name": 2, "host": 3}

    try:
        # Let requests build and escape the query string.
        response = requests.get(
            NCBI_URL,
            params={"db": "biosample", "id": biosample_id, "retmode": "xml"},
            timeout=timeout,
        )
        # Pause before checking the status so error responses are throttled too.
        time.sleep(sleep_time)
        response.raise_for_status()
        data = xmltodict.parse(response.text)

        if not data.get("BioSampleSet"):
            logging.warning(f"No 'BioSampleSet' found for BioSample {biosample_id}")
            return missing

        biosample = data["BioSampleSet"].get("BioSample")
        # xmltodict yields a list when an element repeats; use the first record.
        if isinstance(biosample, list):
            biosample = biosample[0] if biosample else None
        if not biosample:
            logging.warning(f"No 'BioSample' found for BioSample {biosample_id}")
            return missing

        # An empty <Attributes/> element parses to None, hence the `or` fallbacks.
        attributes = (biosample.get("Attributes") or {}).get("Attribute") or []
        # A single <Attribute> parses to one dict rather than a list of dicts.
        if not isinstance(attributes, list):
            attributes = [attributes]
        if not attributes:
            logging.warning(f"No 'Attributes' found for BioSample {biosample_id}")
            return missing

        # Extract metadata
        values = [pd.NA, pd.NA, pd.NA, pd.NA]
        for attr in attributes:
            if not isinstance(attr, dict):
                continue
            slot = slots.get(attr.get("@attribute_name"))
            if slot is not None:
                values[slot] = attr.get("#text", pd.NA)

        result = tuple(values)
        metadata_cache[biosample_id] = result
        return result
    except requests.exceptions.RequestException as e:
        logging.error(f"Network error fetching BioSample {biosample_id}: {e}")
        return missing
    except Exception as e:
        logging.error(f"Unexpected error fetching BioSample {biosample_id}: {e}")
        return missing


def standardize_date(date: str) -> str:
    """Reduce a collection date to its four-digit year.

    The first standalone four-digit number in the value is taken as the year,
    so ``"2015-06-01"``, ``"2015"``, ``"12-Jun-2015"`` and ``"Jun-2015"`` all
    give ``"2015"``.

    Args:
        date: Raw collection date; may be missing (``None``/``pd.NA``/NaN).

    Returns:
        The year as a string, or ``"absent"`` when the value is missing, is a
        placeholder such as "unknown" (matched case-insensitively), or
        contains no standalone four-digit number.
    """
    if pd.isna(date):
        return "absent"

    text = str(date).strip()
    if text.lower() in {"", "unknown", "missing", "na", "not collected"}:
        return "absent"

    # \b keeps longer digit runs such as "20150601" from being read as a year.
    match = re.search(r"\b\d{4}\b", text)
    return match.group(0) if match else "absent"


def standardize_location(location: str) -> str:
    """Reduce a geographic location to the part before the first colon.

    For a value such as ``"USA: California"`` this is ``"USA"``.

    Args:
        location: Raw location value; may be missing (``None``/``pd.NA``/NaN).

    Returns:
        The text before the first ``":"`` with surrounding whitespace removed,
        or ``"absent"`` when the value is missing, empty, or a placeholder
        such as "missing" or "not applicable" (matched case-insensitively).
    """
    if pd.isna(location):
        return "absent"

    # str() guards against non-string cells (e.g. numbers read from a table).
    country = str(location).split(":")[0].strip()
    if country.lower() in {"", "missing", "unknown", "not applicable", "not collected"}:
        return "absent"
    return country


def standardize_host(host: str) -> str:
    """Normalise a host value, mapping missing/placeholder values to "absent".

    Placeholders are matched case-insensitively and after trimming
    whitespace, in the same way as ``standardize_location``, so values such
    as "Missing" or "Not applicable" are not counted as real hosts.

    Args:
        host: Raw host value; may be missing (``None``/``pd.NA``/NaN).

    Returns:
        The host with surrounding whitespace removed, or ``"absent"``.
    """
    if pd.isna(host):
        return "absent"

    text = str(host).strip()
    if text.lower() in {"", "unknown", "missing", "not applicable", "not collected"}:
        return "absent"
    return text


def save_summary(df: pd.DataFrame, output_file: str) -> None:
    """Write a DataFrame to a tab-separated file without the index.

    Args:
        df: Table to write.
        output_file: Destination path.

    Raises:
        Exception: Any error raised while writing is logged and re-raised.
    """
    try:
        df.to_csv(output_file, sep="\t", index=False)
    except Exception as err:
        logging.error(f"Error saving data to {output_file}: {err}")
        raise
    logging.info(f"Data saved to {output_file}")


def plot_bar_charts(variable: str, frequency: pd.Series, percentage: pd.Series, figures_folder: str) -> None:
    """Save side-by-side frequency and percentage bar plots for one variable.

    The figure is written to ``<figures_folder>/<variable>_bar_plots.tiff``
    at 300 dpi.

    Args:
        variable: Variable name used in titles, the x-axis label and the
            output file name.
        frequency: Counts per category (index = category).
        percentage: Percentages per category (index = category).
        figures_folder: Directory the TIFF is written to.

    Raises:
        Exception: Any plotting or saving error is logged and re-raised.
    """
    fig = None
    try:
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        panels = (("Frequency", frequency), ("Percentage", percentage))
        for ax, (label, series) in zip(axes, panels):
            sns.barplot(x=series.index, y=series.values, palette="viridis", ax=ax)
            ax.set_title(f"{label} of {variable}")
            ax.set_xlabel(variable)
            ax.set_ylabel(label)
            # Slanted, right-aligned labels keep long category names readable.
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

        fig.tight_layout()
        figure_path = os.path.join(figures_folder, f"{variable}_bar_plots.tiff")
        fig.savefig(figure_path, format="tiff", dpi=300)
        logging.info(f"Bar plots saved for {variable}")
    except Exception as e:
        logging.error(f"Error generating bar plots for {variable}: {e}")
        raise
    finally:
        # Close the figure on the error path too, so failed plots do not
        # accumulate in matplotlib's figure registry.
        if fig is not None:
            plt.close(fig)


def plot_distribution(column: str, data: pd.Series, title: str, figures_folder: str) -> None:
    """Save a histogram with a KDE overlay for one column.

    The figure is written to ``<figures_folder>/<column>_distribution.tiff``
    at 300 dpi. When ``data`` contains no non-missing values there is nothing
    to draw, so a warning is logged and no file is written.

    Args:
        column: Column name used in the output file name.
        data: Values to plot; missing values are ignored.
        title: Text used in the plot title and as the x-axis label.
        figures_folder: Directory the TIFF is written to.

    Raises:
        Exception: Any plotting or saving error is logged and re-raised.
    """
    fig = None
    try:
        values = pd.Series(data).dropna()
        if values.empty:
            logging.warning(f"No data available for distribution plot of {column}; skipping")
            return

        fig, ax = plt.subplots(figsize=(8, 6))
        sns.histplot(values, kde=True, color="blue", ax=ax)
        ax.set_title(f"Distribution of {title}")
        ax.set_xlabel(title)
        ax.set_ylabel("Frequency")
        fig.tight_layout()
        figure_path = os.path.join(figures_folder, f"{column}_distribution.tiff")
        fig.savefig(figure_path, format="tiff", dpi=300)
        logging.info(f"Distribution plot saved for {column}")
    except Exception as e:
        logging.error(f"Error generating distribution plot for {column}: {e}")
        raise
    finally:
        # Close the figure on the error path too.
        if fig is not None:
            plt.close(fig)


def plot_scatter_with_trend_and_corr(
    x: pd.Series, y: pd.Series, xlabel: str, ylabel: str, title: str, filename: str, figures_folder: str
) -> None:
    """Save a scatter plot with a regression line and Pearson correlation.

    ``x`` and ``y`` are paired by position and coerced to numbers; a pair is
    used only when both of its values are numeric. The same pairs feed the
    correlation, the regression line and the annotation, so the reported
    ``r`` always describes the plotted points. With fewer than two usable
    pairs no correlation exists; a warning is logged and no file is written.

    Args:
        x: Values for the horizontal axis.
        y: Values for the vertical axis; must have the same length as ``x``.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.
        title: Plot title (also used in log messages).
        filename: Output file name inside ``figures_folder``.
        figures_folder: Directory the TIFF is written to.

    Raises:
        Exception: Any error (including unequal input lengths) is logged and
            re-raised.
    """
    fig = None
    try:
        # Build the pairs positionally and drop a pair if either side is NaN.
        # Dropping NaNs from x and y separately would shift values against
        # each other or leave arrays of different lengths.
        nan = float("nan")
        pairs = pd.DataFrame({
            "x": pd.Series(pd.to_numeric(x, errors="coerce")).to_numpy(dtype="float64", na_value=nan),
            "y": pd.Series(pd.to_numeric(y, errors="coerce")).to_numpy(dtype="float64", na_value=nan),
        }).dropna()

        if len(pairs) < 2:
            logging.warning(f"Not enough numeric data points for scatter plot: {title}; skipping")
            return

        # Compute Pearson correlation coefficient on the paired values
        r_value, p_value = stats.pearsonr(pairs["x"], pairs["y"])

        fig, ax = plt.subplots(figsize=(10, 6))

        # Scatter plot with regression line
        sns.regplot(
            x=pairs["x"], y=pairs["y"], scatter_kws={"alpha": 0.5}, line_kws={"color": "red"}, ax=ax
        )

        # Annotate with correlation coefficient in the top-left data corner
        ax.text(
            pairs["x"].min(), pairs["y"].max(),
            f"r = {r_value:.3f} (p={p_value:.3f})",
            fontsize=12, color="black", bbox=dict(facecolor="white", alpha=0.7)
        )

        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        fig.tight_layout()
        figure_path = os.path.join(figures_folder, filename)
        fig.savefig(figure_path, format="tiff", dpi=300)
        logging.info(f"Scatter plot saved: {title}")
    except Exception as e:
        logging.error(f"Error generating scatter plot for {title}: {e}")
        raise
    finally:
        # Close the figure on the error path too.
        if fig is not None:
            plt.close(fig)


def save_clean_data(df: pd.DataFrame, columns_to_keep: List[str], output_file: str) -> None:
    """Write the selected columns of a DataFrame to a comma-separated file.

    Args:
        df: Source table.
        columns_to_keep: Columns to write, in output order.
        output_file: Destination CSV path (written without the index).

    Raises:
        Exception: Any error (e.g. ``KeyError`` for a missing column) is
            logged and re-raised.
    """
    try:
        df[columns_to_keep].to_csv(output_file, index=False)
    except Exception as err:
        logging.error(f"Error saving filtered dataset to {output_file}: {err}")
        raise
    logging.info(f"Filtered dataset saved to {output_file}")


def generate_metadata_summary(df: pd.DataFrame, output_file: str) -> None:
    """Write per-value counts and percentages for the metadata columns.

    For each of "Geographic Location", "Host" and "Collection Date", one row
    per distinct value is written with its count and its share of that
    column's counted values, formatted with two decimals and a ``%`` sign.

    Args:
        df: Table containing the three metadata columns.
        output_file: Destination CSV path (written without the index).

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        rows = []
        for variable in ("Geographic Location", "Host", "Collection Date"):
            counts = df[variable].value_counts()
            n_total = counts.sum()
            rows.extend(
                [variable, label, n, f"{(n / n_total) * 100:.2f}%"]
                for label, n in counts.items()
            )

        table = pd.DataFrame(rows, columns=["Variable", "Value", "Frequency", "Percentage"])
        table.to_csv(output_file, index=False)
        logging.info(f"Metadata summary saved to {output_file}")
    except Exception as err:
        logging.error(f"Error generating metadata summary: {err}")
        raise


def generate_annotation_summary(df: pd.DataFrame, output_file: str) -> None:
    """Write highest/mean/median/lowest for the annotation gene-count columns.

    Each of the three "Annotation Count Gene ..." columns is coerced to
    numbers; non-numeric and missing entries are ignored. A column with no
    numeric values at all is reported as "No Data". ``df`` is not modified.

    Args:
        df: Table containing the three annotation count columns.
        output_file: Destination CSV path (written without the index).

    Raises:
        Exception: Any error (e.g. ``KeyError`` for a missing column) is
            logged and re-raised.
    """
    columns = [
        "Annotation Count Gene Total",
        "Annotation Count Gene Protein-coding",
        "Annotation Count Gene Pseudogene",
    ]
    try:
        summary_data = []
        for column in columns:
            # Keep the cleaned values in a local Series. Assigning them back
            # to df[column] would realign on the index and restore the NaNs.
            values = pd.to_numeric(df[column], errors="coerce").dropna()
            if values.empty:
                highest = mean = median = lowest = "No Data"
            else:
                highest = values.max()
                mean = values.mean()
                median = values.median()
                lowest = values.min()

            summary_data.extend([
                [column, "Highest", highest],
                [column, "Mean", mean],
                [column, "Median", median],
                [column, "Lowest", lowest],
            ])

        summary_df = pd.DataFrame(summary_data, columns=["Variable", "Summary", "Value"])
        summary_df.to_csv(output_file, index=False)
        logging.info(f"Annotation summary saved to {output_file}")
    except Exception as e:
        logging.error(f"Error generating annotation summary: {e}")
        raise


def generate_assembly_summary(df: pd.DataFrame, output_file: str) -> None:
    """Write highest/mean/median/lowest of the total sequence length.

    The "Assembly Stats Total Sequence Length" column is coerced to numbers;
    non-numeric and missing entries are ignored. When no numeric value
    remains, every statistic is reported as "No Data". ``df`` is not modified.

    Args:
        df: Table containing "Assembly Stats Total Sequence Length".
        output_file: Destination CSV path (written without the index).

    Raises:
        Exception: Any error (e.g. ``KeyError`` for a missing column) is
            logged and re-raised.
    """
    column = "Assembly Stats Total Sequence Length"
    try:
        # Keep the cleaned values in a local Series. Assigning them back to
        # df[column] would realign on the index and restore the NaNs.
        lengths = pd.to_numeric(df[column], errors="coerce").dropna()
        if lengths.empty:
            highest = mean = median = lowest = "No Data"
        else:
            highest = lengths.max()
            mean = lengths.mean()
            median = lengths.median()
            lowest = lengths.min()

        summary_data = [
            [column, "Highest", highest],
            [column, "Mean", mean],
            [column, "Median", median],
            [column, "Lowest", lowest],
        ]

        summary_df = pd.DataFrame(summary_data, columns=["Variable", "Summary", "Value"])
        summary_df.to_csv(output_file, index=False)
        logging.info(f"Assembly summary saved to {output_file}")
    except Exception as e:
        logging.error(f"Error generating assembly summary: {e}")
        raise

def download_genome_fasta_ftp(assembly_accession: str, assembly_name: str, output_folder: str) -> None:
    """Download and decompress one genome FASTA from the NCBI genomes site.

    The file is fetched over HTTPS from
    ``genomes/all/<GCA|GCF>/<ddd>/<ddd>/<ddd>/<accession>_<name>/`` and saved
    to ``output_folder`` both compressed (``*_genomic.fna.gz``) and
    decompressed (``*_genomic.fna``).

    This is best-effort: any failure (malformed accession, HTTP error,
    corrupt archive, ...) is logged, partial files are removed, and the
    function returns normally so a batch run can continue.

    Args:
        assembly_accession: Accession such as ``GCF_000005845.2``; its prefix
            (``GCF``/``GCA``) selects the top-level directory.
        assembly_name: Assembly name as listed in the dataset.
        output_folder: Directory to write into; created if missing.
    """
    base_url = "https://ftp.ncbi.nlm.nih.gov/genomes/all"
    gz_filename = fna_filename = None

    try:
        prefix, _, remainder = str(assembly_accession).partition("_")
        digits = remainder.split(".")[0]
        if prefix not in ("GCF", "GCA") or len(digits) != 9 or not digits.isdigit():
            raise ValueError(f"Unrecognised assembly accession: {assembly_accession!r}")

        # NOTE(review): NCBI directory names appear to use underscores in place
        # of spaces in the assembly name; other special characters may need the
        # same treatment. Confirm against the NCBI genomes FTP documentation.
        safe_name = str(assembly_name).strip().replace(" ", "_")
        stem = f"{assembly_accession}_{safe_name}"
        url = (
            f"{base_url}/{prefix}/{digits[:3]}/{digits[3:6]}/{digits[6:9]}"
            f"/{stem}/{stem}_genomic.fna.gz"
        )

        # Create the output folder if it doesn't exist
        os.makedirs(output_folder, exist_ok=True)
        gz_filename = os.path.join(output_folder, f"{stem}_genomic.fna.gz")
        fna_filename = os.path.join(output_folder, f"{stem}_genomic.fna")

        # Download the .fna.gz file. Stream the raw bytes straight to disk so
        # the archive is stored as served, and treat HTTP errors as failures.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(gz_filename, "wb") as gz_file:
                shutil.copyfileobj(response.raw, gz_file)

        # Unzip the .fna.gz file (the compressed copy is kept)
        with gzip.open(gz_filename, "rb") as src, open(fna_filename, "wb") as dst:
            shutil.copyfileobj(src, dst)

        logging.info(f"Downloaded genome FASTA for {assembly_accession} to {fna_filename}")
    except Exception as e:
        logging.error(f"Error downloading genome FASTA for {assembly_accession}: {e}")
        # Remove partial output so a failed download is not mistaken for data.
        for leftover in (gz_filename, fna_filename):
            if leftover and os.path.exists(leftover):
                try:
                    os.remove(leftover)
                except OSError:
                    logging.warning(f"Could not remove partial file {leftover}")



def _attach_biosample_metadata(df: pd.DataFrame, sleep_time: float) -> None:
    """Add the four BioSample metadata columns to ``df`` in place, one lookup per row."""
    columns = ["Isolation Source", "Collection Date", "Geographic Location", "Host"]
    for column in columns:
        df[column] = pd.NA

    for index, row in tqdm(df.iterrows(), total=len(df), desc="Fetching metadata"):
        biosample_id = row["BioSample"]
        if pd.notna(biosample_id):
            values = fetch_metadata(biosample_id, sleep_time)
            for column, value in zip(columns, values):
                df.at[index, column] = value


def _download_sequences(assemblies: pd.DataFrame, sequence_folder: str) -> None:
    """Download one genome FASTA per row that has both an accession and an assembly name."""
    for _, row in tqdm(assemblies.iterrows(), total=len(assemblies), desc="Downloading genome FASTA files"):
        assembly_accession = row["Assembly Accession"]
        assembly_name = row["Assembly Name"]
        if pd.isna(assembly_accession) or pd.isna(assembly_name):
            logging.warning("Skipping row without assembly accession or assembly name")
            continue
        download_genome_fasta_ftp(assembly_accession, assembly_name, sequence_folder)
    logging.info("Sequence downloading completed.")


def main():
    """Run the pipeline: filter, fetch metadata, summarise, plot, optionally download.

    Steps:
        1. Load the input TSV and keep rows passing the CheckM/ANI filters.
        2. Fetch BioSample metadata for every row and standardise it.
        3. Write the updated table, metadata/assembly/annotation summaries
           and the figures.
        4. Write the clean per-assembly CSV.
        5. Only when ``--seq`` is given, download the genome FASTA files.

    Raises:
        SystemExit: With status 1 when any step fails (the error is logged),
            so shells and workflow managers can detect the failure.
    """
    parser = argparse.ArgumentParser(description="Metadata")
    parser.add_argument("--input", required=True, help="Path to the input TSV file")
    parser.add_argument("--outdir", required=True, help="Path to the output directory")
    parser.add_argument("--sleep", type=float, default=0.5, help="Time to wait between requests (default: 0.5s)")
    parser.add_argument("--checkm", type=float, default=95, help="Minimum CheckM completeness threshold (default: 95)")
    parser.add_argument("--seq", action="store_true", help="Run the script to download sequences")
    args = parser.parse_args()

    annotation_columns = [
        "Annotation Count Gene Total",
        "Annotation Count Gene Protein-coding",
        "Annotation Count Gene Pseudogene",
    ]

    try:
        # Load and filter data. Work on a copy so new columns can be added
        # without pandas' SettingWithCopyWarning.
        df = filter_data(load_data(args.input), args.checkm).copy()
        if df.empty:
            raise ValueError("No rows passed the CheckM/ANI filters; nothing to process.")

        # Create output directories
        organism_name = str(df["Organism Name"].iloc[0]).replace(" ", "_")
        organism_folder, metadata_folder, figures_folder, sequence_folder = create_output_directory(
            args.outdir, organism_name
        )

        # Fetch metadata
        _attach_biosample_metadata(df, args.sleep)

        # Standardize columns
        df["Collection Date"] = df["Collection Date"].apply(standardize_date)
        df["Geographic Location"] = df["Geographic Location"].apply(standardize_location)
        df["Host"] = df["Host"].apply(standardize_host)

        # Save updated data
        save_summary(df, os.path.join(metadata_folder, "ncbi_dataset_updated.tsv"))

        # One row per assembly name: sort "GCF" accessions first so that
        # drop_duplicates(keep="first") prefers them over other accessions.
        df_sorted = df.sort_values(
            by="Assembly Accession",
            key=lambda x: x.str.startswith("GCF", na=False),
            ascending=False,
        )
        df2 = df_sorted.drop_duplicates(subset=["Assembly Name"], keep="first").copy()

        # Save metadata and assembly summaries
        generate_metadata_summary(df2, os.path.join(metadata_folder, "metadata_summary.csv"))
        generate_assembly_summary(df2, os.path.join(metadata_folder, "assembly_summary.csv"))

        # Generate and save bar plots
        for variable in ["Geographic Location", "Host", "Collection Date"]:
            frequency = df2[variable].value_counts()
            percentage = (frequency / frequency.sum()) * 100
            plot_bar_charts(variable, frequency, percentage, figures_folder)

        # Save annotation summary (PGAP-annotated assemblies only)
        df3 = df[df["Annotation Name"] == "NCBI Prokaryotic Genome Annotation Pipeline (PGAP)"].copy()
        generate_annotation_summary(df3, os.path.join(metadata_folder, "annotation_summary.csv"))

        # Generate distribution plots for annotation columns
        for column in annotation_columns:
            df3[column] = pd.to_numeric(df3[column], errors="coerce")
            if df3[column].notna().any():
                plot_distribution(column, df3[column], column, figures_folder)
            else:
                logging.warning(f"No numeric values in '{column}'; distribution plot skipped")

        # Filter and sort data for scatter plots
        df3_filtered = df3[df3["Collection Date"] != "absent"].copy()
        df3_filtered["Collection Date"] = pd.to_numeric(df3_filtered["Collection Date"], errors="coerce")
        df3_filtered = df3_filtered.dropna(subset=["Collection Date"])
        df3_filtered = df3_filtered.sort_values(by="Collection Date", ascending=True)

        # Generate scatter plots with trend lines
        scatter_targets = [
            ("Annotation Count Gene Total", "scatter_plot_gene_total_vs_collection_date.tiff"),
            ("Annotation Count Gene Protein-coding", "scatter_plot_gene_protein_coding_vs_collection_date.tiff"),
        ]
        if len(df3_filtered) < 2:
            # A correlation needs at least two dated assemblies.
            logging.warning("Fewer than two assemblies with a collection year; scatter plots skipped")
        else:
            for column, filename in scatter_targets:
                plot_scatter_with_trend_and_corr(
                    x=df3_filtered["Collection Date"],
                    y=df3_filtered[column],
                    xlabel="Collection Date",
                    ylabel=column,
                    title=f"Scatter Plot: {column} vs Collection Date",
                    filename=filename,
                    figures_folder=figures_folder,
                )

        # Save clean data
        columns_to_keep = [
            "Organism Name", "BioSample", "Assembly Accession", "Assembly Name", "Assembly BioProject Accession",
            "Organism Infraspecific Names Strain", "Assembly Stats Total Sequence Length",
            "Isolation Source", "Collection Date", "Geographic Location", "Host"
        ]
        clean_data_file = os.path.join(metadata_folder, "ncbi_clean.csv")
        save_clean_data(df2, columns_to_keep, clean_data_file)

        # Download sequences only when --seq is provided
        if args.seq:
            _download_sequences(df2[columns_to_keep], sequence_folder)

        logging.info("Script completed successfully.")
    except Exception as e:
        logging.error(f"Script failed: {e}", exc_info=True)
        # Exit non-zero so callers can tell the run failed.
        raise SystemExit(1) from e

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
