#!/usr/bin/env -S casa --nologfile --log2term --nologger -c  # noqa: EXE003
from __future__ import annotations

import shlex
import subprocess as sp
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from pathlib import Path
from typing import NamedTuple


def get_target_name(vis_path: Path) -> str:
    """Return a cleaned-up name for the first field of a MeasurementSet.

    Reads the first entry of the ``NAME`` column in the ``FIELD`` sub-table,
    then normalises it: spaces become underscores, a leading ``field_``
    prefix is stripped, and a trailing ``_0`` suffix is stripped.

    Args:
        vis_path: Path to the MeasurementSet directory.

    Returns:
        The normalised name of the first field.
    """
    prefix = "field_"
    suffix = "_0"

    # `tb` is the CASA table tool (a casa-shell global, not imported here).
    # Close it even if reading the column fails so the table is not left open.
    tb.open(str(vis_path / "FIELD"))
    try:
        name = str(tb.getcol("NAME")[0])
    finally:
        tb.close()

    name = name.replace(" ", "_")
    # Strip only the leading prefix; `str.replace` would also remove any
    # later occurrence of "field_" inside the name.
    if name.startswith(prefix):
        name = name[len(prefix):]
    if name.endswith(suffix):
        name = name[: -len(suffix)]
    return name


def quack_data(vis_path: Path) -> Path:
    """Quack-flag a MeasurementSet in place using CASA ``flagdata``.

    Runs ``flagdata`` in ``quack`` mode with ``quackmode="beg"`` and
    ``quackinterval=10.0``, without saving a flag backup.

    Args:
        vis_path: Path to the MeasurementSet to flag.

    Returns:
        The same ``vis_path`` (the data set is modified in place).
    """
    quack_kwargs = {
        "vis": vis_path.as_posix(),
        "mode": "quack",
        "quackmode": "beg",
        "quackinterval": 10.0,
        "flagbackup": False,
    }
    flagdata(**quack_kwargs)
    return vis_path


def aoflagger(vis_path: Path) -> Path:
    """Run the external ``aoflagger`` program on a MeasurementSet.

    Args:
        vis_path: Path to the MeasurementSet passed to ``aoflagger``.

    Returns:
        The same ``vis_path``; aoflagger operates on the data set itself.

    Raises:
        subprocess.CalledProcessError: If ``aoflagger`` exits non-zero.
        FileNotFoundError: If the ``aoflagger`` executable cannot be found.
    """
    # Pass an argument list directly. Formatting a string and re-splitting it
    # with shlex would break paths that contain spaces or quote characters.
    cmd = ["aoflagger", vis_path.as_posix()]
    _ = sp.run(cmd, check=True)
    return vis_path


class DelayResult(NamedTuple):
    """Outputs of :func:`calculate_delays`."""

    # MeasurementSet the delay solutions were applied to.
    vis_path: Path
    # Calibration table written by ``gaincal``.
    gain_path: Path


class DelayOptions(NamedTuple):
    """Settings used by :func:`calculate_delays` for the ``gaincal`` step."""

    # Reference antenna name(s), passed to ``gaincal`` as ``refant``.
    refant: str = "s9-2"
    # Baseline-length selection, passed to ``gaincal`` as ``uvrange``.
    uvrange: str = ">500m"
    # Output delay table; when None, calculate_delays derives a name from the
    # target/field name next to the input MeasurementSet.
    cal_table: Path | None = None


def calculate_delays(
    vis_path: Path,
    target_name: str,
    delay_options: DelayOptions,
) -> DelayResult:
    """Solve for delays with CASA ``gaincal`` and apply them with ``applycal``.

    Args:
        vis_path: MeasurementSet to calibrate.
        target_name: Name used to build the default table name
            ``<target_name>.delay.tb`` next to ``vis_path``.
        delay_options: ``refant``/``uvrange`` for ``gaincal`` and an optional
            explicit output table path.

    Returns:
        A :class:`DelayResult` holding ``vis_path`` and the delay table path.
    """
    ms_name = vis_path.as_posix()
    delay_table = (
        delay_options.cal_table
        if delay_options.cal_table is not None
        else vis_path.parent / f"{target_name}.delay.tb"
    )
    table_name = delay_table.as_posix()

    # gaintype="K" asks gaincal for a delay solution.
    gaincal(
        vis=ms_name,
        caltable=table_name,
        uvrange=delay_options.uvrange,
        antenna="*&",
        solint="inf",
        refant=delay_options.refant,
        refantmode="strict",
        minblperant=2,
        gaintype="K",
    )
    applycal(vis=ms_name, gaintable=[table_name])

    return DelayResult(vis_path, delay_table)


class AverageOptions(NamedTuple):
    """Averaging settings forwarded to ``mstransform`` by :func:`average`."""

    # Number of input channels averaged into each output channel (``chanbin``).
    chanbin: int = 1
    # Time-averaging bin width, passed to ``mstransform`` as ``timebin``.
    timebin: str = "0s"


def average(
    vis: Path,
    output_vis: Path,
    average_options: AverageOptions,
) -> Path:
    """Write a time- and channel-averaged copy of ``vis`` with ``mstransform``.

    All data columns are carried over (``datacolumn="all"``).

    Args:
        vis: Input MeasurementSet.
        output_vis: Path of the MeasurementSet to create.
        average_options: Channel and time bin widths.

    Returns:
        ``output_vis``.
    """
    bins = {
        "chanbin": average_options.chanbin,
        "timebin": average_options.timebin,
    }
    mstransform(
        vis=vis.as_posix(),
        outputvis=output_vis.as_posix(),
        datacolumn="all",
        chanaverage=True,
        timeaverage=True,
        **bins,
    )
    return output_vis


def pre_process(
    vis_path: Path,
    average_options: AverageOptions,
    delay_options: DelayOptions,
    output_vis: Path | None = None,
) -> Path:
    """Run the full pre-processing chain on a MeasurementSet.

    The steps, in order, are:

    1. Quack-flag the input.
    2. Run aoflagger on it (both in place).
    3. Solve and apply delays.
    4. Average into ``output_vis``.
    5. Run aoflagger again on the averaged data.

    Args:
        vis_path: Input MeasurementSet (modified in place by the early steps).
        average_options: Settings for the averaging step.
        delay_options: Settings for the delay calibration step.
        output_vis: Averaged output path; defaults to ``<field>.av.ms`` next
            to the input, where ``<field>`` comes from :func:`get_target_name`.

    Returns:
        Path of the averaged, re-flagged MeasurementSet.
    """
    target_name = get_target_name(vis_path)
    destination = (
        output_vis
        if output_vis is not None
        else vis_path.parent / f"{target_name}.av.ms"
    )

    flagged_vis = aoflagger(vis_path=quack_data(vis_path))

    delay_result = calculate_delays(
        vis_path=flagged_vis,
        target_name=target_name,
        delay_options=delay_options,
    )

    averaged_vis = average(
        vis=delay_result.vis_path,
        output_vis=destination,
        average_options=average_options,
    )
    # Flag a second time, now on the averaged product.
    return aoflagger(vis_path=averaged_vis)


def get_parser() -> ArgumentParser:
    """Build the command-line parser for this pre-processing script.

    Returns:
        An ``ArgumentParser`` with the input MeasurementSet as a positional
        argument and options for output paths, averaging and delay solving.
    """
    # The previous description ("Get the field name from a MS") was a
    # copy-paste leftover and did not describe what the script does.
    parser = ArgumentParser(
        description=(
            "Pre-process a MeasurementSet: quack-flag, run aoflagger, "
            "solve and apply delays, then average and run aoflagger again."
        ),
        formatter_class=ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("vis", type=Path, help="Input MeasurementSet.")
    parser.add_argument(
        "--output-vis",
        type=Path,
        help=(
            "Output MeasurementSet. If not set, '<field name>.av.ms' "
            "next to the input is used."
        ),
        default=None,
    )
    parser.add_argument(
        "--output-delay-table",
        type=Path,
        help=(
            "Output delay table. If not set, '<field name>.delay.tb' "
            "next to the input is used."
        ),
        default=None,
    )
    parser.add_argument(
        "--chanbin",
        type=int,
        default=1,
        help="Width (bin) of input channels to average to form an output channel.",
    )
    parser.add_argument(
        "--timebin", type=str, default="0s", help="Bin width for time averaging."
    )
    parser.add_argument(
        "--refant",
        type=str,
        default="s9-2",
        help="Reference antenna name(s) in `gaincal`.",
    )
    parser.add_argument(
        "--uvrange",
        type=str,
        default=">500m",
        help="Select data by baseline length in `gaincal`.",
    )

    return parser


def main() -> None:
    """Command-line entry point: parse arguments and run :func:`pre_process`."""
    args = get_parser().parse_args()

    pre_process(
        vis_path=args.vis,
        output_vis=args.output_vis,
        average_options=AverageOptions(
            chanbin=args.chanbin,
            timebin=args.timebin,
        ),
        delay_options=DelayOptions(
            refant=args.refant,
            uvrange=args.uvrange,
            cal_table=args.output_delay_table,
        ),
    )


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script.
    main()
