# -*- coding: utf-8 -*-
"""
Module c50_rules.py
----------------
C5.0 rules algorithm wrapper.
C5.0 is an open-source decision tree algorithm written in the C programming language.
For more information, please refer to the [official website](https://www.rulequest.com/see5-unix.html#USE).
"""
import logging
import pathlib
import shlex
import subprocess as sp
from typing import Union
from uuid import uuid4

import pandas as pd
from jinja2 import Environment

from .names_file_builder import NamesFileBuilder
from .parsers import RuleSet
from .estimator_loader import load_estimator


class C50Rules(object):
    """Train a C5.0 rule set and expose it as a Python estimator.

    ``fit`` writes the ``.names``, ``.data``, optional ``.test`` and optional
    ``.costs`` files into ``working_dir`` and runs the bundled ``bin/c5.0``
    executable with ``-r``. It then renders the parsed rules into
    ``estimator_<id>.py`` in the same directory. That generated module is
    loaded lazily by the :attr:`estimator` property, to which ``predict``,
    ``batch_predict`` and ``rules`` delegate.

    Every generated file name carries the instance ``id``.
    """

    # Logger for debug output (replaces ad-hoc prints to stdout).
    _logger = logging.getLogger(__name__)

    def __init__(self, working_dir='/temp/', estimator_id=None):
        """
        :param working_dir: directory for the C5.0 input/output files and the
            generated estimator module. It is created, together with any
            missing parents, if it does not exist. NOTE(review): the default
            ``'/temp/'`` looks like a typo for ``'/tmp/'``; it is kept
            unchanged for backward compatibility.
        :param estimator_id: suffix used in every generated file name. When
            omitted, the last group of a random UUID4 is used.
        :raises FileExistsError: if ``working_dir`` exists but is not a directory.
        """
        if estimator_id is None:
            self.id = str(uuid4()).rsplit('-', maxsplit=1)[-1]
        else:
            self.id = estimator_id

        # exist_ok removes the race of a separate exists() check;
        # parents allows nested directories that do not exist yet.
        pathlib.Path(working_dir).mkdir(parents=True, exist_ok=True)

        self.working_dir = working_dir
        self._estimator = None
        # the order in which the features are defined in the .names file;
        # this attribute is initialized in the `self._write_names_file` method.
        self._features_order: list = None
        # distinct target values seen by the last `fit` call.
        self._labels: list = None

    def _path(self, filename):
        """Return the path of *filename* inside ``working_dir`` as a string."""
        return f'{self.working_dir}/{filename}'

    @property
    def estimator(self):
        """Instance of the generated ``Estimator`` class, loaded on first access."""
        if self._estimator is None:
            Estimator = load_estimator(self._path(f'estimator_{self.id}.py'))
            self._estimator = Estimator()
        return self._estimator

    @property
    def rules(self):
        """The ``rules`` attribute of the generated estimator."""
        return self.estimator.rules

    def fit(
        self,
        x_train: pd.DataFrame, y_train: pd.Series,
        x_test: Union[pd.DataFrame, None], y_test: Union[pd.Series, None],
        descrete_values_subset=False,
        winnow=False,
        disable_global_prunning=False,
        prunning_confidence_factor=.25,
        initial_tree_constraint_degree=2,
        data_sampling=None,
        sampling_seed=None,
        weights=None,
        stack_size=20000,  # `ulimit -s` units (KiB), i.e. ~20MB
    ):
        """Run C5.0 on the given data and generate the rule-based estimator.

        :param x_train: training features.
        :param y_train: training target; its ``name`` is used as the target
            attribute name in the ``.names`` file.
        :param x_test: optional test features (``None`` to skip the ``.test`` file).
        :param y_test: optional test target; must be ``None`` iff ``x_test`` is.
        :param weights: optional value rendered into the ``.costs`` file. When
            ``None``, ``-e`` is passed so C5.0 ignores any costs file.
        :param stack_size: soft stack limit applied with ``ulimit -Ss``
            before running the binary.
        :raises ValueError: if only one of ``x_test``/``y_test`` is given.
        :raises RuntimeError: if the C5.0 binary exits with a non-zero code.

        The remaining keyword arguments map onto C5.0 command-line options
        (see ``_get_extra_args``).
        """
        if (x_test is None) != (y_test is None):
            raise ValueError('x_test and y_test must either both be given or both be None')

        self._labels = y_train.unique().tolist()

        train_set = pd.concat([x_train, y_train], axis=1)
        # A missing test set is allowed: `_write_data_files` skips the
        # `.test` file when it receives None.
        test_set = None if x_test is None else pd.concat([x_test, y_test], axis=1)

        self._write_files(
            train_set=train_set,
            test_set=test_set,
            target_name=y_train.name,
            weights=weights,
        )

        extra_args = self._get_extra_args(
            descrete_values_subset=descrete_values_subset,
            winnow=winnow,
            disable_global_prunning=disable_global_prunning,
            prunning_confidence_factor=prunning_confidence_factor,
            initial_tree_constraint_degree=initial_tree_constraint_degree,
            data_sampling=data_sampling,
            sampling_seed=sampling_seed,
            weights=weights,
        )

        base_path = pathlib.Path(__file__).parent.parent.resolve()
        binary = shlex.quote(f'{base_path}/bin/c5.0')
        file_stem = shlex.quote(self._path(f'c5.0-{self.id}'))
        # `ulimit` is a shell builtin, so the binary must run through a shell.
        # Quoting keeps paths containing spaces or metacharacters intact.
        command = (
            f'ulimit -Ss {shlex.quote(str(stack_size))} && '
            f'{binary} -f {file_stem} -r {extra_args}'
        )

        c50 = sp.run(command, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)

        c50_output = c50.stdout.decode('utf-8')
        # stderr is only used in the error message, so tolerate bad bytes.
        stderr = c50.stderr.decode('utf-8', errors='replace')
        # Keep the raw report on disk whether or not the run succeeded.
        with open(self._path(f'c5.0-{self.id}.output.txt'), 'w', encoding='utf-8') as f:
            f.write(c50_output)

        if c50.returncode != 0:
            raise RuntimeError(
                "C5.0 binary execution failed with code {}\n\n{}\n\n{}".format(
                    c50.returncode,
                    stderr,
                    c50_output,
                )
            )
        self._parse_output(c50_output)

    def _get_extra_args(
        self,
        descrete_values_subset=False,
        winnow=False,
        disable_global_prunning=False,
        prunning_confidence_factor=.25,
        initial_tree_constraint_degree=2,
        data_sampling=None,
        sampling_seed=None,
        weights=None,
    ):
        """Translate the ``fit`` options into a space-separated C5.0 option string.

        Options equal to the wrapper defaults are omitted so that C5.0 falls
        back to its own defaults.
        """
        extra_args = []
        if descrete_values_subset:
            extra_args.append('-s')

        if winnow:
            extra_args.append('-w')

        if disable_global_prunning:
            extra_args.append('-g')

        if prunning_confidence_factor is not None and prunning_confidence_factor != .25:
            # prunning_error_rate = 100 disables the initial prunning
            # NOTE(review): C5.0's `-c` expects a percentage (its default is 25),
            # so a fraction such as 0.5 is passed verbatim as 0.5% — confirm
            # the intended units.
            extra_args.append(f'-c {prunning_confidence_factor}')

        if initial_tree_constraint_degree is not None and initial_tree_constraint_degree != 2:
            extra_args.append(f'-m {initial_tree_constraint_degree}')

        if data_sampling is not None:
            extra_args.append(f'-S {data_sampling}')

        if sampling_seed is not None:
            extra_args.append(f'-I {sampling_seed}')

        if weights is None:
            # no costs supplied: tell C5.0 to ignore any .costs file
            extra_args.append('-e')

        return ' '.join(extra_args)

    def predict(self, X):
        """Delegate to the generated estimator's ``predict``."""
        return self.estimator.predict(X)

    def batch_predict(self, X):
        """Delegate to the generated estimator's ``batch_predict``."""
        return self.estimator.batch_predict(X)

    def _write_names_file(self, dataset, target_name):
        """Write ``c5.0-<id>.names`` and record the feature order it uses."""
        builder = NamesFileBuilder()
        builder.build_and_save(
            path=self._path(f'c5.0-{self.id}.names'),
            dataset=dataset,
            target_name=target_name,
        )
        self._features_order = builder.features_order

    def _write_data_files(self, train_set: pd.DataFrame, test_set: Union[pd.DataFrame, None]):
        """Write ``c5.0-<id>.data`` and, when given, ``c5.0-<id>.test``.

        Columns are reordered to match the ``.names`` file. Files have no
        header and no index.
        """
        train_set.loc[:, self._features_order].to_csv(
            self._path(f'c5.0-{self.id}.data'),
            header=False,
            index=False,
        )

        if test_set is not None:
            test_set.loc[:, self._features_order].to_csv(
                self._path(f'c5.0-{self.id}.test'),
                header=False,
                index=False,
            )

    @staticmethod
    def _render_template(template_name, **context):
        """Render ``templates/<template_name>`` (next to this module) with *context*."""
        template_path = pathlib.Path(__file__).parent.resolve() / 'templates' / template_name
        template_source = template_path.read_text(encoding='utf8')
        return Environment().from_string(template_source).render(**context)

    def _write_weights_file(self, weights):
        """Render ``weights.costs.jinja`` with *weights* into ``c5.0-<id>.costs``.

        The accepted structure of *weights* is whatever the template expects.
        """
        rendered_costs = self._render_template('weights.costs.jinja', weights=weights)
        with open(self._path(f'c5.0-{self.id}.costs'), 'w', encoding='utf-8') as f:
            f.write(rendered_costs)

    def _write_files(
        self,
        train_set: pd.DataFrame,
        test_set: Union[pd.DataFrame, None],
        target_name: str,
        weights: Union[pd.DataFrame, None],
    ):
        """Write every input file C5.0 needs for one run."""
        # The .names file must come first: it sets `_features_order`, which
        # the data files rely on.
        self._write_names_file(
            dataset=train_set,
            target_name=target_name,
        )

        self._write_data_files(
            train_set=train_set,
            test_set=test_set,
        )
        if weights is not None:
            self._write_weights_file(weights)

    def _parse_output(self, c50_output, rule_based=False):
        """Render the rules in *c50_output* into ``estimator_<id>.py``.

        :param rule_based: unused; kept for backward compatibility.
        """
        rendered_estimator = self._parse_estimator(c50_output)
        with open(self._path(f'estimator_{self.id}.py'), 'w', encoding='utf-8') as f:
            f.write(rendered_estimator)
        # Drop any estimator loaded from a previous fit so the next access
        # loads the freshly written module.
        # NOTE(review): assumes load_estimator re-reads the file on every
        # call — confirm.
        self._estimator = None

    @staticmethod
    def _extract_rules_section(c50_output):
        """Return the text between ``Rules:`` and ``Evaluation on training data``.

        :raises ValueError: if either marker is missing from *c50_output*.
        """
        start_marker = 'Rules:'
        end_marker = 'Evaluation on training data'
        start = c50_output.find(start_marker)
        end = c50_output.find(end_marker, start + len(start_marker)) if start >= 0 else -1
        if start < 0 or end < 0:
            raise ValueError(
                f'C5.0 output has no rule set between {start_marker!r} and {end_marker!r}'
            )
        return c50_output[start + len(start_marker):end].strip()

    def _parse_estimator(self, c50_output):
        """Parse the rule set in *c50_output* and render the estimator module source."""
        rules_parser = RuleSet(self._features_order, self._labels)
        rules_def = self._extract_rules_section(c50_output)

        self._logger.debug('::::rules_def:::::\n%s', rules_def)
        rule_set = rules_parser.parse_string(rules_def, parse_all=True)[0]

        return self._render_template(
            'rule_estimator_template.py.jinja',
            rule_set=rule_set,
            estimator_name='Estimator',
        )
