import os
import pandas as pd

from aitana.whakaari import Gas, Seismicity
from aitana.util import eqRate, reindex

# Common analysis window (UTC) applied to every Whakaari time series below.
STARTDATE = pd.to_datetime("2005-01-01", utc=True)
ENDDATE = pd.to_datetime("2021-01-01", utc=True)

# Default target: requesting the final poster PDF pulls in the whole workflow.
rule all:
    input:
        "poster/poster.pdf"

# Plot publications-per-year from a Scopus search export for the poster.
rule scopus_search_plot:
    input:
        "data/Scopus-20530-Analyze-Year.csv"
    output:
        "poster/scopus_search_plot_small.png"
    script:
        "scripts/plot_publications.py"

# Fetch the published Whakaari random-forest model code from GitHub.
# NOTE: the previous `cd {output.git_dir} && git init ...` sequence failed when
# the directory did not yet exist -- Snakemake creates parents of outputs but
# not directory() outputs themselves. `git clone` creates the directory and
# checks out the requested branch in a single step.
rule download_random_forest_model:
    output:
        git_dir=directory("decision_tree_repo/"),
        forecast_model="decision_tree_repo/scripts/forecast_model.py"
    shell:
        """
        git clone --branch master https://github.com/ddempsey/whakaari.git {output.git_dir}
        """

# Run the decision-tree forecast model and copy its consensus time series into
# the results directory.
rule run_random_forest_model:
    input:
        "decision_tree_repo/scripts/forecast_model.py"
    output:
        "results/decision_tree_consensus.csv"
    conda:
        "envs/whakaari_decision_tree.yaml"
    shell:
        # The model run happens in a subshell so the working directory is
        # restored automatically, and the whole pipeline is &&-chained: if the
        # python step fails, the (possibly stale) consensus.csv is NOT copied.
        # The previous version relied on separate lines without chaining, which
        # under a non-strict shell could copy stale output after a failure.
        """
        (cd decision_tree_repo/scripts && \
         python -c "from forecast_model import forecast_dec2019; forecast_dec2019()") && \
        cp decision_tree_repo/predictions/fm_2.00wndw_0.75ovlp_2.00lkfd_dsar-hf-mf-rsam/consensus.csv {output}
        """

# Build a daily earthquake-rate series from the Whakaari catalogue. The date
# index written here doubles as the master grid that the gas/RSAM rules align
# their own series to.
rule seismicity:
    output:
        "data/seismicity.csv"
    run:
        catalogue = Seismicity(STARTDATE, ENDDATE).quakes()
        # Fixed 7-event window rate, resampled to daily means and gap-filled.
        daily = eqRate(catalogue, fixed_time=7).resample("D").mean().interpolate()
        full_index = pd.date_range(daily.index[0], ENDDATE, freq="D")
        aligned = reindex(daily, full_index)
        pd.DataFrame({'Eqr': aligned}).to_csv(output[0])

# Align the CO2 flux series onto the seismicity rule's daily date grid.
rule co2:
    input:
        "data/seismicity.csv"
    output:
        "data/co2.csv"
    run:
        # The seismicity CSV's index defines the common daily grid.
        grid = pd.read_csv(input[0], index_col=0, parse_dates=True).index
        aligned = reindex(Gas(STARTDATE, ENDDATE).co2(), grid)
        pd.DataFrame({'CO2': aligned}).to_csv(output[0])

# Align the SO2 flux series onto the seismicity rule's daily date grid.
rule so2:
    input:
        "data/seismicity.csv"
    output:
        "data/so2.csv"
    run:
        # The seismicity CSV's index defines the common daily grid.
        grid = pd.read_csv(input[0], index_col=0, parse_dates=True).index
        aligned = reindex(Gas(STARTDATE, ENDDATE).so2(), grid)
        pd.DataFrame({'SO2': aligned}).to_csv(output[0])

# Align the H2S flux series onto the seismicity rule's daily date grid.
rule h2s:
    input:
        "data/seismicity.csv"
    output:
        "data/h2s.csv"
    run:
        # The seismicity CSV's index defines the common daily grid.
        grid = pd.read_csv(input[0], index_col=0, parse_dates=True).index
        aligned = reindex(Gas(STARTDATE, ENDDATE).h2s(), grid)
        pd.DataFrame({'H2S': aligned}).to_csv(output[0])

# Align the RSAM (seismic amplitude) series onto the seismicity daily grid.
rule rsam:
    input:
        "data/seismicity.csv"
    output:
        "data/rsam.csv"
    run:
        # The seismicity CSV's index defines the common daily grid.
        grid = pd.read_csv(input[0], index_col=0, parse_dates=True).index
        aligned = reindex(Seismicity(STARTDATE, ENDDATE).rsam(), grid)
        pd.DataFrame({'RSAM': aligned}).to_csv(output[0])

# Outer-join all single-variable daily series into one wide table. Column
# names come from each CSV's own header (Eqr, CO2, SO2, H2S, RSAM), so no
# renaming is required here.
rule combine_data:
    input:
        "data/seismicity.csv",
        "data/co2.csv",
        "data/so2.csv",
        "data/h2s.csv",
        "data/rsam.csv"
    output:
        "data/whakaari_data.csv"
    run:
        # Snakemake verifies all declared inputs exist before `run` executes,
        # so the old per-file existence check was unreachable; an unused
        # `colname` derived from each filename was also dead code -- both
        # removed.
        frames = [pd.read_csv(f, index_col=0, parse_dates=True) for f in input]
        pd.concat(frames, join="outer", axis=1).to_csv(output[0])

# Discretize the combined observables into groups for the Bayesian network.
rule assign_groups:
    input:
        "data/whakaari_data.csv" 
    output:
        "data/whakaari_data_w_groups.csv"
    conda:
        "envs/whakaari_bayesian_network.yaml"
    script:
        "scripts/assign_groups.py"

# Grid-search hyperparameters for the Bayesian-network model.
rule hyperparameter_tuning:
    input:
        "data/whakaari_data_w_groups.csv"
    output:
        "results/grid_search_results.csv"
    conda:
        "envs/whakaari_bayesian_network.yaml"
    script:
        "scripts/bn_grid_search.py"

# Produce Bayesian-network forecasts using the tuned hyperparameters.
rule bayesian_network_forecasts:
    input:
        "data/whakaari_data_w_groups.csv",
        "results/grid_search_results.csv"
    output:
        "results/bn_forecasts.nc"
    conda:
        "envs/whakaari_bayesian_network.yaml"
    script:
        "scripts/bn_forecasts.py"

# Convert both models' forecasts into per-model netCDF files under
# results/whakaari_forecasts/ via the tonik Storage layer.
rule combine_forecasts:
    input:
        "results/bn_forecasts.nc",
        "results/decision_tree_consensus.csv"
    output:
        expand("results/whakaari_forecasts/{model}.nc",
               model=["bayesian_network", "random_forest"]),
    run:
        import xarray as xr
        import pandas as pd
        from tonik import Storage
        # Bayesian-network probabilities: variable 'probs' indexed by
        # 'datetime' in the netCDF written by bn_forecasts.py.
        bn_xds = xr.open_dataset(input[0])
        rf_csv = pd.read_csv(input[1], parse_dates=True, index_col=0)
        # NOTE(review): assumes tonik.Storage("whakaari_forecasts",
        # rootdir="results") writes each Dataset variable to
        # results/whakaari_forecasts/<variable>.nc so the file names match this
        # rule's declared outputs -- confirm against tonik's documentation.
        st = Storage("whakaari_forecasts", rootdir="results")
        st.save(xr.Dataset({'bayesian_network': bn_xds['probs']}, coords={"datetime": bn_xds['datetime']}))
        # Give the consensus series the same coordinate name before converting
        # it to an xarray Dataset.
        rf_csv.index.name = "datetime"
        rf_xds = xr.Dataset({'random_forest': rf_csv['consensus'].to_xarray()}, coords={"datetime": rf_csv.index.values})
        st.save(rf_xds)

# Plot both models' forecasts against the observed data for comparison.
rule plot_forecasts:
    input:
        expand("results/whakaari_forecasts/{model}.nc",
               model=["bayesian_network", "random_forest"]),
        "data/whakaari_data_w_groups.csv"
    output:
        "poster/forecast_comparison.png"
    script:
        "scripts/plot_forecasts.py"

# Plot the scoring function evaluated on both models' forecasts.
rule plot_scoring_function:
    input:
        expand("results/whakaari_forecasts/{model}.nc",
               model=["bayesian_network", "random_forest"]),
    output:
        "poster/scoring_function.png"
    script:
        "scripts/plot_scoring_function.py"

# Plot ROC curves for both forecast models against the observed data.
rule plot_roc_curves:
    input:
        "data/whakaari_data.csv",
         expand("results/whakaari_forecasts/{model}.nc",
               model=["bayesian_network", "random_forest"]),
    output:
        "poster/roc_curves.png"
    script:
        "scripts/plot_roc_curves.py"

# Render this workflow's own DAG as a figure for the poster.
# NOTE(review): this invokes snakemake recursively from inside a running
# workflow; `--dag` is a dry-run-style operation, but confirm it does not
# clash with the outer run's working-directory lock.
rule plot_dag:
    output:
        "poster/dag.png"
    conda:
        "envs/graphviz.yaml"
    shell:
        "snakemake --dag | dot -Tpng -Gdpi=300 > {output}"
# Generate the QR code image embedded in the poster.
rule generate_qr_code:
    output:
        "poster/qr_code.png"
    conda:
        "envs/qrcode.yaml"
    script:
        "scripts/generate_qr_code.py"

# Assemble the final poster PDF. The SVG template references the generated
# PNGs, which is why they are declared as inputs even though only the SVG is
# passed to inkscape on the command line.
rule generate_poster:
    input:
        "poster/forecast_comparison.png",
        "poster/scopus_search_plot_small.png",
        "poster/dag.png",
        "poster/qr_code.png",
        "poster/scoring_function.png",
        "poster/roc_curves.png",
        "poster/poster_small.svg"
    output:
        "poster/poster.pdf"
    container:
        "docker://minidocks/inkscape:latest"
    shell:
        """
        inkscape ./poster/poster_small.svg --export-type=pdf --export-filename={output}
        """
