#!python

import yaml
import os,sys
import argparse
import datetime
#import ast
from copy import deepcopy
import subprocess
from prompt_toolkit import prompt

from desispec.calibfinder import CalibFinder,sm2sp

def safe_eval(s):
    """
    Parses a string as YAML, falling back to the raw string when parsing fails.

    Args:
        s (str): The text to parse.

    Returns:
        Whatever yaml.safe_load produces for the text, or the unchanged input
        string if the parser raises a YAML error or a ValueError.
    """
    try:
        parsed = yaml.safe_load(s)
    except (yaml.YAMLError, ValueError):
        # not valid YAML: hand the text back untouched
        return s
    return parsed

def load_yaml(file_path):
    """
    Reads a YAML file and collects its leading comment lines.

    Args:
        file_path (str): Path to the YAML file.

    Returns:
        tuple: (data, header) where data is the result of yaml.safe_load on the
        file and header is the concatenation of the consecutive lines at the
        top of the file that start with '#'.
    """
    # the header is made of the first lines starting with '#'
    header_lines = []
    with open(file_path, 'r') as stream:
        for text in stream.readlines():
            if not text.startswith("#"):
                break
            header_lines.append(text)
    header = "".join(header_lines)

    # parse from a fresh file object so that the whole content is read
    with open(file_path, 'r') as stream:
        content = yaml.safe_load(stream)
    return content, header

# Representer used to control how lists are written to the YAML file
class CustomRepresenter(yaml.representer.SafeRepresenter):
    def represent_list(self, data, depth=0):
        """
        Builds the YAML sequence node for a list.

        The list is written in flow style ([a, b]) unless depth >= 1, in which
        case block style is used.

        NOTE(review): PyYAML appears to call registered representers with
        (dumper, data) only, in which case depth keeps its default of 0 and
        lists are always written in flow style -- confirm against PyYAML.

        Args:
            data (list): The list to represent.
            depth (int): The depth level selecting the style.

        Returns:
            yaml.nodes.SequenceNode: The YAML node representing the list.
        """
        use_flow_style = not depth >= 1
        return self.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=use_flow_style)

# Register the list representer on the representer class itself
CustomRepresenter.add_representer(list, CustomRepresenter.represent_list)

# Dumper handed to yaml when saving; it only differs from SafeDumper
# by the list representer registered right below
class CustomDumper(yaml.SafeDumper):
    """SafeDumper that represents lists with CustomRepresenter.represent_list."""
CustomDumper.add_representer(list, CustomRepresenter.represent_list)

def save_yaml(data, header, file_path):
    """
    Writes the header followed by the YAML dump of the data to a file.

    The header is followed by one extra newline, and the data is dumped as a
    single document with CustomDumper, keeping the keys in their current order.

    Args:
        data (dict): The data to save.
        header (str): The header comments to write at the top of the file.
        file_path (str): Path to the output YAML file (overwritten if it exists).
    """
    with open(file_path, 'w') as stream:
        stream.write(header + "\n")
        yaml.dump(data, stream, Dumper=CustomDumper, sort_keys=False)

def yes_or_no(question):
    """
    Prompts the user for a yes/no response.

    The question is shown without its question marks and with a " (y/n)? "
    suffix. The prompt is repeated until the answer is 'y' or 'n'
    (case-insensitive, surrounding whitespace ignored).

    Args:
        question (str): The question to ask the user.

    Returns:
        str: 'y' for yes, 'n' for no, always lower case.
    """
    # normalize the answer: callers compare the result with "y" / "n", so an
    # upper case 'Y' or 'N' must not be returned as typed
    yn = prompt(question.replace("?","").strip()+" (y/n)? ").strip().lower()
    while yn not in ('y', 'n'):
        yn = prompt("Please type 'y' or 'n'? ").strip().lower()
    return yn

def input_date(message) :
    """
    Prompts the user to enter a date in YYYYMMDD format.

    The prompt is repeated until the entry is an integer whose year is in
    2000..3000 and whose month and day form an existing calendar date.

    Args:
        message (str): The prompt message to display to the user.

    Returns:
        int: The entered date as the integer YYYYMMDD.
    """
    while True :
        yyyymmdd = safe_eval(prompt(message.strip().replace(":","")+": "))
        # any further prompt reminds the expected format
        message  = "Please use YYYYMMDD format: "
        # exact type check: anything that is not parsed as a plain integer is rejected
        if type(yyyymmdd) is not int :
            continue
        yy, mmdd = divmod(yyyymmdd, 10000)
        mm, dd = divmod(mmdd, 100)
        if yy < 2000 or yy > 3000 :
            print(f"Invalid year {yy}")
            continue
        if mm < 1 or mm > 12 :
            print(f"Invalid month {mm}")
            continue
        if dd < 1 or dd > 31 :
            print(f"Invalid day {dd}")
            continue
        try :
            # reject days that do not exist in this month (e.g. 20230231)
            datetime.date(yy, mm, dd)
        except ValueError :
            print(f"Invalid day {dd}")
            continue
        return yyyymmdd # return integer

def _entry_error(value, shown) :
    """
    Checks that a parsed value has a shape accepted for a configuration entry.

    Accepted shapes are a scalar (int, float, or str without spaces or square
    brackets) or a two-item list [scalar, 'comment string'].

    Args:
        value: The parsed value to check.
        shown: What to display between quotes in the error message to
            identify the faulty entry.

    Returns:
        str or None: The error message to print, or None if the value is accepted.
    """
    # exact type comparisons on purpose: bool is a subclass of int but a
    # boolean is not an accepted value
    scalar_types = (int, float, str)

    if type(value) == list :
        if len(value) != 2 :
            return f"\nERROR with entry '{shown}', lists must have a length of 2"
        if type(value[1]) != str :
            return f"\nERROR with entry '{shown}', second item in list must be a string"
        if type(value[0]) not in scalar_types :
            return f"\nERROR with entry '{shown}', first item in list must be a int , float or string"
        return None

    if type(value) not in scalar_types :
        return f"\nERROR with entry '{shown}', value must be a int , float , string or list\nand not:\n{type(value)}"

    if type(value) == str :
        if " " in value :
            return f"\nERROR with entry '{shown}', value cannot have spaces unless it is a list with '[]'"
        if '[' in value or ']' in value :
            return f"\nERROR with entry '{shown}', value interpreted as str but contains [,]. Please make sure the string values in the list are between '' so that the entry can be interpreted as a list with a string element."
    return None

def input_keyval() :
    """
    Prompts the user to enter a key-value pair.

    The entry is split at the first ':'. The key is stripped of surrounding
    whitespace and the value is parsed with safe_eval. Unless the key is
    'COMMENT', the value must be a scalar or a [value, 'comment'] list;
    the prompt is repeated until the entry is valid.

    Returns:
        tuple: A tuple containing the key and the value.
    """
    while True :
        kv = prompt("Enter a key:value or key:[value,'comment'] : ")
        raw_key, colon, raw_value = kv.partition(":")
        # strip the key so that 'KEY : value' edits 'KEY' and not 'KEY '
        k = raw_key.strip()
        if not colon or not k :
            print(f"\nERROR Cannot split '{kv}' as key:value\n")
            continue
        v = safe_eval(raw_value.strip())

        if k != 'COMMENT' : # more flexible for the comment string
            error = _entry_error(v, kv)
            if error is not None :
                print(error)
                continue
        return k,v

def input_value(key) :
    """
    Prompts the user to enter a value for a given key.

    The entry is parsed with safe_eval. Unless the key is 'COMMENT', the value
    must be a scalar or a [value, 'comment'] list; the prompt is repeated
    until the entry is valid.

    Args:
        key (str): The key for which the value is being entered.

    Returns:
        The entered value.
    """
    while True :
        v = safe_eval(prompt(f"{key}: "))
        if key != 'COMMENT' : # more flexible for the comment string
            error = _entry_error(v, v)
            if error is not None :
                print(error)
                continue
        return v

def edit_configuration(configuration_data):
    """
    Edits the configuration data interactively.

    The entries are printed, then any entry whose value is "EDIT" has to be
    filled in (one per iteration). Once there is none left, the user can
    (a)dd or modify an entry, (r)emove entries, (s)ave, which returns to the
    caller, or (q)uit, which exits the program with status 0.

    Args:
        configuration_data (dict): The configuration data to edit, modified in place.
    """
    action_keys=['a','r','s','q']
    while True :
        print("----------------")
        print()
        keys_to_edit = []
        for key,val in configuration_data.items() :
            print(f'{key}:{val}')
            if val=="EDIT" :
                keys_to_edit.append(key)
        print()

        if len(keys_to_edit)> 0 :
            # fill in the first pending entry, the others follow in the next iterations
            key=keys_to_edit[0]
            if key in ["DATE-OBS-BEGIN","DATE-OBS-END"] :
                configuration_data[key]=input_date(f"{key}: ")
            else :
                configuration_data[key]=input_value(key)
            continue

        akey=prompt('(a)dd or modify entry, (r)emove, (s)ave or (q)quit? ').strip().lower()
        print(akey)
        while akey not in action_keys :
            # same normalization as for the first answer
            akey = prompt(f"Please type a key in {action_keys}: ").strip().lower()

        if akey == "q" :
            sys.exit(0)
        elif akey == "s" :
            break
        elif akey == "r" :
            ikeys = prompt("Enter key(s) to remove:")
            # keys can be separated by spaces and/or commas; split() drops the
            # empty items that an entry like 'A, B' would otherwise produce
            keys = ikeys.replace(","," ").split()
            print(keys)
            for k in keys :
                if k in configuration_data :
                    print("Removing",k)
                    configuration_data.pop(k)
                else :
                    # do not abort the editing session for a mistyped key
                    print(f"No entry '{k}' to remove")
        elif akey == "a" :
            k,v = input_keyval()
            configuration_data[k]=v


def test_calibration_finder(data,filename) :
    """
    Runs the calibration finder on the edited configuration as a sanity check.

    Nothing is done (and True is returned) if DESI_SPECTRO_CALIB is not set.
    Otherwise DESI_SPECTRO_DARK is removed from the environment, and if the
    file name has the form <smid>-<arm>.<ext> and resolves to the same path as
    $DESI_SPECTRO_CALIB/spec/<smid>/<basename>, a CalibFinder is created from a
    header built with the configuration data, and each calibration file it
    knows about among a fixed list of keys is required to exist.

    Args:
        data (dict): The configuration data.
        filename (str): The path to the calibration file.

    Returns:
        bool: False if an exception occurred during the test, True otherwise.
    """
    if "DESI_SPECTRO_CALIB" not in os.environ :
        return True

    calib_keys = ["PSF","BIAS","DARK","PIXFLAT","MASK","FIBERFLAT","FLUXCALIB","SKYCORR","SKYGRADPCA","TPCORRPARAM","FIBERCROSSTALK"]

    try :
        if "DESI_SPECTRO_DARK" in os.environ :
            # do not use DESI_SPECTRO_DARK info which can result
            # in errors unrelated to the edition of DESI_SPECTRO_CALIB
            os.environ.pop("DESI_SPECTRO_DARK")

        basefilename = os.path.basename(filename)
        name_parts = basefilename.split(".")[0].split("-")
        if len(name_parts) == 2 :
            smid, arm = name_parts
            spid = sm2sp(smid).lower().replace("sp","")
            camera = f"{arm}{spid}"
            calib_dir = os.environ["DESI_SPECTRO_CALIB"]
            expected_path = os.path.realpath(f"{calib_dir}/spec/{smid}/{basefilename}")
            edited_path = os.path.realpath(filename)
            if expected_path != edited_path :
                print("(no test of the calibration finder because the expected and edited yaml files are not the same)")
                print(f" expected: {expected_path}")
                print(f" edited  : {edited_path}")
            else :
                header = {name: data[name] for name in ["DETECTOR","CCDCFG","CCDTMING"]}
                header["NIGHT"] = data["DATE-OBS-BEGIN"]
                header["CAMERA"] = camera
                header["SPECID"] = smid.lower().replace("sm","")
                print(f"Testing the calibration finder with header={header}")
                cfinder = CalibFinder([header],filename)
                # now check some files
                for calib_key in calib_keys :
                    if not cfinder.haskey(calib_key) :
                        continue
                    calib_file = cfinder.findfile(calib_key)
                    if not os.path.isfile(calib_file) :
                        raise IOError(f"Missing {calib_key} file {calib_file}")
                    print(f"{calib_key} {calib_file} OK")
    except Exception as err :
        # any failure, whatever its origin, is reported and flags the test as failed
        print("ERROR while testing the calibration finder:",err)
        return False

    print("OK")
    return True

def main():
    """
    Main function to run the DESI spectroscopic calibration configuration editor.

    Parses the command line, loads the YAML file, lets the user pick a camera
    id and either create a new configuration (as a copy of an existing one,
    inserted at the beginning) or modify an existing one. The result is
    written to a temporary '<output>-tmp' file, compared with the input file
    using the 'diff' command, tested with the calibration finder and, once
    accepted, renamed to the output file.

    Returns:
        0 once the changes have been saved, None if the input file does not
        exist. The program exits with status 12 if the YAML file has no entry
        or if the selected camera id has no configuration.
    """
    parser = argparse.ArgumentParser(description="DESI spectroscopic calibration configuration editor")
    parser.add_argument('yaml_file', type = str, help = 'path of yaml file to modify')

    args = parser.parse_args()
    if not os.path.isfile(args.yaml_file) :
        print(f"The file {args.yaml_file} does not exist.")
        return

    data,header = load_yaml(args.yaml_file)

    # an empty file is loaded as None instead of an empty dictionary,
    # so test the type before looking at the keys
    if not isinstance(data, dict) or len(data) == 0 :
        print("There are no entries in the YAML file. There needs to be at least one defining the camera id")
        sys.exit(12)

    camids = list(data.keys())
    if len(camids) > 1 :
        camid=prompt(f"Which camid do you want to edit among {camids}? ")
        while camid not in camids :
            camid=prompt(f"Please choose one among {camids}? ")
    else :
        camid=camids[0]
    camid_data=data[camid]

    # both the creation (copy of a previous configuration) and the
    # modification need at least one existing configuration
    if not isinstance(camid_data, dict) or len(camid_data) == 0 :
        print(f"There are no configurations for {camid} in the YAML file. There needs to be at least one")
        sys.exit(12)

    configuration_names = list(camid_data.keys())
    print(f"Current configuration names for {camid} in the YAML file:")
    print(configuration_names)

    last_config_name = None
    previous_config_name = None
    if yes_or_no("Do you want to create a new configuration?") == "y" :
        # default name from today's date, with a numeric suffix if already used
        base_name=datetime.datetime.now().strftime('V%Y%m%d')
        config_name=base_name
        iterator=1
        while config_name in configuration_names :
            config_name=base_name+f"-{iterator}"
            iterator += 1

        while yes_or_no(f"Is '{config_name}' OK?")=="n" :
            while True :
                config_name=prompt("Enter a new config name: ")
                if config_name in configuration_names :
                    print(f"Sorry {config_name} already exists")
                else :
                    break
        print(config_name)
        # the first configuration in the file is taken as the latest one
        last_config_name=configuration_names[0]
        previous_config_name=last_config_name
        yn = yes_or_no(f"Do you want to start from a copy of '{previous_config_name}' ?")
        if yn == "n" :
            while True :
                previous_config_name=prompt(f"Choose a previous configuration among {configuration_names}: ")
                if previous_config_name not in configuration_names :
                    print(f"'{previous_config_name}' is not in the list.")
                else :
                    break
        print(f"Starting from a copy of '{previous_config_name}' to '{config_name}'")
        # would like to insert at the beginning, so need a full copy:
        new_data = dict()
        new_data[config_name]=deepcopy(camid_data[previous_config_name])
        for k in ["DATE-OBS-BEGIN","COMMENT"] :
            new_data[config_name][k]="EDIT" # force edit of this entry
        if "DATE-OBS-END" in new_data[config_name].keys() :
            new_data[config_name].pop("DATE-OBS-END")

        for c in camid_data.keys() :
            new_data[c] = camid_data[c]
        data[camid] = new_data
        camid_data  = data[camid]

    else :
        while True :
            config_name=prompt("Pick a configuration to modify: ")
            if config_name not in configuration_names :
                print(f"'{config_name}' is not in the list:")
                print(configuration_names)
            else :
                break
        print(f"Will edit '{config_name}'")

    while True :

        edit_configuration(camid_data[config_name])

        if last_config_name is not None :
            # a new configuration was created: offer to close the validity range of the latest one
            if yes_or_no(f"Change DATE-OBS-END of previous config {last_config_name}?") == "y" :
                yyyymmdd = input_date("Enter DATE-OBS-END: ")
                camid_data[last_config_name]["DATE-OBS-END"]=yyyymmdd

        yn = yes_or_no(f"Overwrite '{args.yaml_file}'?")
        if yn == "y" :
            outfilename=args.yaml_file
        else :
            outfilename=prompt("Enter output file name: ")

        # write to a temporary file first, it only replaces the output file once accepted
        save_yaml(data, header, outfilename+"-tmp")
        print("Changes:")
        subprocess.call(["diff",args.yaml_file,outfilename+"-tmp"])
        print()

        if yes_or_no("OK?") == "y" :
            test_succeeded = test_calibration_finder(camid_data[config_name],outfilename+"-tmp")
            if test_succeeded or ( yes_or_no("Save anyway?") == "y" ):
                os.rename(outfilename+"-tmp",outfilename)
                print(f"Changes saved to {outfilename}")
                return 0

        os.unlink(outfilename+"-tmp")
        print("Resume editing")



# Run the interactive editor only when this file is executed as a script
if __name__ == "__main__":
    main()
