#!python
"""
CLI script for chest X-ray registration.

Usage:
    cxas_register -i {input} -o {out_dir} [-r {reference}] [options]
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from tqdm import tqdm

# Root logging setup: INFO and above, rendered as "<time> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


def parse_arguments(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse (without the program name). When ``None``
            (the default), ``sys.argv[1:]`` is used, exactly as before; passing
            an explicit list makes the parser usable from code and tests.

    Returns:
        The parsed ``argparse.Namespace``. Exits via ``SystemExit`` on invalid
        arguments or ``--help`` (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(
        description="Register chest X-ray images using landmark-based affine transformation.",
        epilog="Written by Constantin Seibold. If you use this tool, please cite accordingly."
    )

    parser.add_argument(
        "-i", "--input",
        metavar="filepath",
        type=str,
        required=True,
        help="Path to file or directory to be processed."
    )

    # Not argparse-required: --build-reference mode does not need it; main()
    # enforces it for registration mode.
    parser.add_argument(
        "-o", "--out_dir",
        metavar="directory",
        type=str,
        required=False,
        default=None,
        help="Output directory for registered images. Required unless using --build-reference."
    )

    parser.add_argument(
        "-r", "--reference",
        metavar="filepath",
        type=str,
        default=None,
        help="Path to reference image, directory, or .npz file. "
             "If not provided, uses default reference or computes from input directory."
    )

    parser.add_argument(
        "-g", "--gpus",
        default="cpu",
        help="Select specific GPU/CPU to process the input (default: 'cpu')."
    )

    parser.add_argument(
        "-m", "--model",
        choices=["UNet_ResNet50_default"],
        default="UNet_ResNet50_default",
        help="Model used for inference (default: UNet_ResNet50_default)."
    )

    parser.add_argument(
        "--no-correction",
        action="store_true",
        help="Skip orientation/color correction."
    )

    parser.add_argument(
        "--save-mask",
        action="store_true",
        help="Save registered segmentation masks."
    )

    parser.add_argument(
        "--build-reference",
        action="store_true",
        help="Build reference features from input (no registration performed)."
    )

    # Checked in build_reference_from_input() rather than by argparse, since it
    # is only required together with --build-reference.
    parser.add_argument(
        "--reference-out",
        metavar="filepath",
        type=str,
        default=None,
        help="Path to save built reference (.npz). Required with --build-reference."
    )

    return parser.parse_args(argv)


def build_reference_from_input(args):
    """Build reference features from an input image or directory and save them.

    For a single file, a full reference (orientation features + landmarks) is
    built from that image. For a directory, the orientation reference is built
    from the first loaded image, and T4/T10 landmark coordinates averaged over
    all images are added when the average is valid.

    Args:
        args: Parsed CLI namespace; reads ``input``, ``reference_out``,
            ``model`` and ``gpus``.

    Exits with status 1 (via ``sys.exit``) when ``--reference-out`` is missing,
    the input path is neither a file nor a directory, or a directory yields no
    images.
    """
    # Validate cheap preconditions before the heavy model/torch imports so that
    # obvious CLI mistakes fail fast.
    if args.reference_out is None:
        logging.error("--reference-out is required when using --build-reference")
        sys.exit(1)

    input_path = Path(args.input)
    if not (input_path.is_file() or input_path.is_dir()):
        logging.error(f"{input_path} is neither a file nor a directory.")
        sys.exit(1)

    import numpy as np
    import torch
    from cxas import CXAS
    from cxas.registration import (
        ReferenceBuilder,
        LandmarkExtractor,
        save_reference
    )
    from cxas.registration.landmarks import compute_average_landmarks

    logging.info(f"Building reference from: {args.input}")

    model = CXAS(model_name=args.model, gpus=args.gpus)
    builder = ReferenceBuilder(model)
    landmark_extractor = LandmarkExtractor()

    if input_path.is_file():
        # Single file: segment it and build a full reference from it.
        file_dict = model.fileloader.load_file(str(input_path))

        # The model is called with batch-style list fields for the metadata.
        with torch.no_grad():
            predictions = model.model({
                "data": file_dict["data"],
                "filename": [file_dict["filename"]],
                "file_size": [file_dict["file_size"]]
            })

        segmentation = predictions["segmentation_preds"][0].cpu().numpy()
        landmarks = landmark_extractor.extract(segmentation)

        reference = builder.build_full_reference(
            file_dict["data"],
            landmarks,
            str(input_path)
        )

    else:
        # Directory: collect landmarks from every image and average them.
        from cxas.file_io import get_folder_loader

        dataloader = get_folder_loader(str(input_path), args.gpus, batch_size=1)
        all_landmarks = []
        first_file_data = None

        for file_dict in tqdm(dataloader, desc="Processing files"):
            if first_file_data is None:
                first_file_data = file_dict["data"]

            with torch.no_grad():
                predictions = model.model(file_dict)

            # One segmentation per filename in the batch.
            for index, _ in enumerate(predictions["filename"]):
                seg = predictions["segmentation_preds"][index].cpu().numpy()
                all_landmarks.append(landmark_extractor.extract(seg))

        # Without this guard an empty directory would pass None on to the
        # reference builder and fail with an obscure error.
        if first_file_data is None:
            logging.error(f"No images found in {input_path}")
            sys.exit(1)

        # Orientation reference comes from the first image only.
        reference = builder.build_orientation_reference(
            first_file_data,
            f"average_from_{len(all_landmarks)}_images"
        )

        avg_landmarks = compute_average_landmarks(all_landmarks)
        if avg_landmarks is not None and avg_landmarks.valid:
            reference["t4_x"] = np.array(avg_landmarks.t4.x)
            reference["t4_y"] = np.array(avg_landmarks.t4.y)
            reference["t10_x"] = np.array(avg_landmarks.t10.x)
            reference["t10_y"] = np.array(avg_landmarks.t10.y)
            reference["num_samples"] = np.array(len(all_landmarks))
        else:
            logging.warning("Could not compute average landmarks from input directory")

    save_reference(reference, args.reference_out)
    logging.info(f"Reference saved to: {args.reference_out}")


def _build_reference_from_path(model, ref_path: Path, gpus):
    """Build a full reference from an image file, or from the first image of a directory.

    Args:
        model: Loaded ``CXAS`` model used for loading and segmentation.
        ref_path: Existing image file or directory.
        gpus: Device spec forwarded to ``FolderDataset`` for directories.

    Returns:
        Whatever ``ReferenceBuilder.build_full_reference`` returns.

    Exits with status 1 if ``ref_path`` is a directory with no images.
    """
    import torch
    from cxas.registration import ReferenceBuilder, LandmarkExtractor

    builder = ReferenceBuilder(model)
    landmark_extractor = LandmarkExtractor()

    if ref_path.is_file():
        file_dict = model.fileloader.load_file(str(ref_path))
    else:
        from cxas.file_io import FolderDataset

        dataset = FolderDataset(str(ref_path), gpus)
        if len(dataset) == 0:
            logging.error(f"No images found in {ref_path}")
            sys.exit(1)
        file_dict = dataset[0]
        # Dataset samples are presumably un-batched; add a leading batch axis
        # to 3-D tensors before calling the model.
        if file_dict["data"].dim() == 3:
            file_dict["data"] = file_dict["data"].unsqueeze(0)

    with torch.no_grad():
        predictions = model.model({
            "data": file_dict["data"],
            "filename": [file_dict["filename"]],
            "file_size": [file_dict["file_size"]]
        })
    seg = predictions["segmentation_preds"][0].cpu().numpy()
    landmarks = landmark_extractor.extract(seg)
    return builder.build_full_reference(file_dict["data"], landmarks, str(ref_path))


def _save_temporary_reference(reference) -> str:
    """Write *reference* to a new temporary ``.npz`` file and return its path.

    The OS handle from ``mkstemp`` is closed immediately so nothing leaks and
    ``save_reference`` can open the path itself. The caller owns the file and
    is responsible for deleting it.
    """
    import tempfile
    from cxas.registration import save_reference

    fd, temp_path = tempfile.mkstemp(suffix=".npz")
    os.close(fd)
    try:
        save_reference(reference, temp_path)
    except BaseException:
        # Don't leave a half-written temp file behind.
        os.remove(temp_path)
        raise
    return temp_path


def _register_file(model, registrator, input_path: Path, output_dir: Path, save_mask: bool) -> None:
    """Register one image file and save the result into *output_dir*."""
    from cxas.registration.registrator import save_registration_result

    logging.info(f"Processing file: {input_path}")

    file_dict = model.fileloader.load_file(str(input_path))
    result = registrator.register_single(file_dict, save_mask=save_mask)
    save_registration_result(result, str(output_dir), save_mask=save_mask)

    if result.success:
        logging.info(f"Registration completed successfully. Output saved to: {output_dir}")
    else:
        logging.warning(f"Registration completed with errors: {result.error}")


def _register_directory(registrator, input_path: Path, output_dir: Path, gpus, save_mask: bool) -> None:
    """Register every image in *input_path*, saving results and logging a summary."""
    from cxas.file_io import get_folder_loader
    from cxas.registration.registrator import save_registration_result

    logging.info(f"Processing directory: {input_path}")

    dataloader = get_folder_loader(str(input_path), gpus, batch_size=1)

    success_count = 0
    error_count = 0

    for file_dict in tqdm(dataloader, desc="Registering images"):
        result = registrator.register_single(file_dict, save_mask=save_mask)
        save_registration_result(result, str(output_dir), save_mask=save_mask)

        if result.success:
            success_count += 1
        else:
            error_count += 1
            logging.warning(f"Failed to register {result.filename}: {result.error}")

    logging.info(
        f"Registration completed. Success: {success_count}, Errors: {error_count}. "
        f"Output saved to: {output_dir}"
    )


def main() -> None:
    """Main entry point for the script.

    With ``--build-reference`` it delegates to ``build_reference_from_input``.
    Otherwise it registers the input file/directory into ``--out_dir``:

    * ``-r`` pointing at an ``.npz`` (case-insensitive) is passed straight to
      the ``Registrator``.
    * ``-r`` pointing at an image or directory is turned into a reference first.
      That reference goes into a temporary ``.npz``, which is deleted when the
      run finishes.
    * Any other ``-r`` value, or none, is forwarded unchanged.

    Exits with status 1 on missing ``--out_dir`` or an input path that is
    neither a file nor a directory.
    """
    args = parse_arguments()

    if args.build_reference:
        build_reference_from_input(args)
        return

    if args.out_dir is None:
        logging.error("-o/--out_dir is required unless using --build-reference")
        sys.exit(1)

    # Validate the input before the expensive model load / reference build.
    input_path = Path(args.input)
    if not (input_path.is_file() or input_path.is_dir()):
        logging.error(f"{input_path} is neither a file nor a directory.")
        sys.exit(1)

    from cxas import CXAS
    from cxas.registration import Registrator

    model = CXAS(model_name=args.model, gpus=args.gpus)

    temp_reference = None
    try:
        reference_path = args.reference
        if reference_path is not None:
            ref_path = Path(reference_path)
            is_image_source = ref_path.is_dir() or (
                ref_path.is_file() and ref_path.suffix.lower() != ".npz"
            )
            if is_image_source:
                logging.info(f"Building reference from: {reference_path}")
                reference = _build_reference_from_path(model, ref_path, args.gpus)
                temp_reference = _save_temporary_reference(reference)
                reference_path = temp_reference

        registrator = Registrator(
            model=model,
            reference_path=reference_path,
            do_correction=not args.no_correction
        )

        output_dir = Path(args.out_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        if input_path.is_file():
            _register_file(model, registrator, input_path, output_dir, args.save_mask)
        else:
            _register_directory(registrator, input_path, output_dir, args.gpus, args.save_mask)
    finally:
        # The temp reference is only needed for the duration of this run.
        if temp_reference is not None:
            try:
                os.remove(temp_reference)
            except OSError as exc:
                logging.warning(f"Could not remove temporary reference {temp_reference}: {exc}")


if __name__ == "__main__":
    main()
