#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = ["kodelet-sdk"]
# [tool.uv.sources]
# kodelet-sdk = { path = "../..", editable = true }
# ///
from __future__ import annotations

import json
from typing import Any, Literal

from kodelet_sdk import BaseModel, Extension, Field

# Storage key for the remembered bash allow/deny lists; it is the name passed
# to ctx.storage.read_json / ctx.storage.write_json.
BASH_POLICY_FILE = "bash-policy.json"
# Labels offered in the bash approval prompt. The label returned by the prompt
# is compared against these constants verbatim, so each string is both the
# user-facing text and the value the handler branches on.
ALLOW_ONCE = "Allow once"
ALLOW_ALWAYS = "Allow and remember this exact command"
DENY_ONCE = "Deny once"
DENY_ALWAYS = "Deny and remember this exact command"
# The four labels in the order they are handed to the approval prompt.
BASH_DECISION_OPTIONS = [ALLOW_ONCE, ALLOW_ALWAYS, DENY_ONCE, DENY_ALWAYS]

# Extension instance on which the decorators in this file register the tool
# and the event handlers.
ext = Extension(name="workspace", version="0.1.0")


class BashPolicy(BaseModel):
    # Remembered per-command decisions, as read from and written to
    # BASH_POLICY_FILE. Both fields default to an empty list.
    # Commands that the approval hook lets through without prompting.
    allowed: list[str] = Field(default_factory=list)
    # Commands that the approval hook blocks without prompting.
    denied: list[str] = Field(default_factory=list)


class BashCommandDetails(BaseModel):
    # Fields pulled out of a bash tool call's input by
    # parse_bash_command_details.
    # Command text; the parser strips it and never builds this model with an
    # empty command.
    command: str
    # Optional human-readable description, shown in the approval prompt.
    description: str | None = None
    # Optional timeout, rendered in the approval prompt with an "s" suffix
    # (presumably seconds - TODO confirm against the bash tool's schema).
    timeout: int | float | None = None


class AskUserChoiceInput(BaseModel):
    # Input model for the ask_user_choice tool; it is passed to the tool
    # decorator as input_schema.
    # Non-empty question, used as the title of the selection prompt.
    question: str = Field(min_length=1, description="The question to ask the user")
    # Between two and five option labels, presented to the user as given.
    options: list[str] = Field(
        min_length=2,
        max_length=5,
        description="The options to choose from (2-5 items)",
    )


@ext.tool(
    "ask_user_choice",
    description=(
        "Present user with multiple choices when there are several possible approaches and you "
        "need them to pick one. Use when you have 2-5 concrete options to choose from"
    ),
    input_schema=AskUserChoiceInput,
)
async def ask_user_choice(input: AskUserChoiceInput, ctx):
    """Ask the user to pick one of the supplied options and describe the outcome.

    Args:
        input: Validated tool input carrying the question and its options.
        ctx: Extension context; only ``ctx.ui.select`` is used.

    Returns:
        A sentence describing the result: the 1-based position and text of
        the chosen option, the raw response when it matches none of the
        options, or a dismissal notice when the response is falsy.
    """
    options = input.options
    prompt = {
        "title": input.question,
        "message": "Choose one option.",
        "options": options,
        "submitButtonText": "Select",
    }
    choice = await ctx.ui.select(prompt)

    if not choice:
        return "User dismissed the question without choosing."

    # A response that is not one of the offered options is reported verbatim.
    if choice not in options:
        return f"User responded with: {choice}"

    position = options.index(choice) + 1
    return f"User selected option {position}: {choice}"


@ext.on("agent.start")
async def notify_ready(_event, ctx):
    """Notify the user that the extension started and summarize the stored policy.

    Args:
        _event: The ``agent.start`` event payload; unused.
        ctx: Extension context used to read the policy and send the notification.
    """
    policy = await read_bash_policy(ctx)
    allowed_count = len(policy.allowed)
    denied_count = len(policy.denied)
    summary = (
        "Workspace extension started. Remembered bash policy: "
        f"{allowed_count} allowed, {denied_count} denied."
    )
    notification = {"title": "Workspace extension ready", "message": summary}
    await ctx.ui.notify(notification)


@ext.on("tool.call", priority=100, timeout_in_sec=300)
async def approve_bash_command(event, ctx):
    """Gate ``bash`` tool calls behind the remembered policy or a user prompt.

    Args:
        event: The ``tool.call`` event; ``event.tool.name`` and
            ``event.tool.input`` are read.
        ctx: Extension context used for storage, logging and the prompt.

    Returns:
        ``None`` when this hook raises no objection (another tool, input with
        no usable command, a remembered or freshly granted approval), or a
        ``{"block": {"reason": ...}}`` mapping when the command is denied by
        the remembered policy, by the user, or because the prompt returned
        none of the known options.
    """
    if event.tool.name != "bash":
        return None

    details = parse_bash_command_details(event.tool.input)
    if details is None:
        # No usable command could be extracted, so there is nothing to judge.
        return None

    command = details.command
    policy = await read_bash_policy(ctx)

    # A remembered denial is consulted before a remembered approval.
    if command in policy.denied:
        remembered_denial = f"Bash command denied by remembered workspace policy: {command}"
        return {"block": {"reason": remembered_denial}}
    if command in policy.allowed:
        return None

    prompt = {
        "title": "Allow bash command?",
        "message": format_bash_prompt(details, policy),
        "options": BASH_DECISION_OPTIONS,
        "submitButtonText": "Apply",
        "cancelButtonText": "Deny",
    }
    choice = await ctx.ui.select(prompt)

    if choice == ALLOW_ONCE:
        return None
    if choice == ALLOW_ALWAYS:
        await remember_bash_decision(ctx, command, "allowed", policy)
        return None

    if choice == DENY_ALWAYS:
        await remember_bash_decision(ctx, command, "denied", policy)
        reason = f"Bash command denied and remembered by workspace policy: {command}"
    elif choice == DENY_ONCE:
        reason = f"Bash command denied by user: {command}"
    else:
        # Any other response (for example a cancelled prompt) counts as a denial.
        reason = f"Bash command denied because no approval was received: {command}"
    return {"block": {"reason": reason}}


def parse_bash_command_details(input_value: Any) -> BashCommandDetails | None:
    """Extract the command, description and timeout from a bash tool input.

    Args:
        input_value: The raw tool input. A string is decoded as JSON first;
            any other value is used as-is.

    Returns:
        A ``BashCommandDetails`` when the input is a mapping with a non-blank
        string ``command``; otherwise ``None``. The command and description
        are stripped of surrounding whitespace, a blank or non-string
        description becomes ``None``, and a non-numeric timeout becomes
        ``None``.
    """
    parsed = parse_json_object(input_value) if isinstance(input_value, str) else input_value
    if not isinstance(parsed, dict):
        return None

    command_value = parsed.get("command")
    command = command_value.strip() if isinstance(command_value, str) else ""
    if not command:
        return None

    description_value = parsed.get("description")
    description = description_value.strip() if isinstance(description_value, str) else ""

    timeout_value = parsed.get("timeout")
    # bool is a subclass of int, so a JSON true/false would otherwise pass the
    # numeric check and be treated as a timeout.
    is_number = isinstance(timeout_value, (int, float)) and not isinstance(timeout_value, bool)
    timeout = timeout_value if is_number else None

    return BashCommandDetails(
        command=command,
        description=description or None,
        timeout=timeout,
    )


def format_bash_prompt(details: BashCommandDetails, policy: BashPolicy) -> str:
    """Build the multi-line message shown in the bash approval prompt.

    Args:
        details: The parsed command; ``command``, ``description`` and
            ``timeout`` are read.
        policy: The current policy; only the lengths of ``allowed`` and
            ``denied`` are used.

    Returns:
        Newline-joined text with the command, the description when it is
        truthy, the timeout when it is not ``None``, and a count of the
        remembered decisions.
    """
    sections = [
        "A bash command is about to run.",
        "",
        "Command:",
        details.command,
    ]

    if details.description:
        sections += ["", f"Description: {details.description}"]
    # Compared against None so that a timeout of 0 is still displayed.
    if details.timeout is not None:
        sections.append(f"Timeout: {details.timeout}s")

    allowed_count = len(policy.allowed)
    denied_count = len(policy.denied)
    sections.append("")
    sections.append(f"Remembered decisions: {allowed_count} allowed, {denied_count} denied.")
    return "\n".join(sections)


async def read_bash_policy(ctx) -> BashPolicy:
    """Load the remembered bash policy, falling back to an empty one on error.

    Args:
        ctx: Extension context providing ``storage.read_json`` and ``log.warn``.

    Returns:
        The normalized stored policy, or an empty ``BashPolicy`` when reading
        or normalizing raises; the failure is logged as a warning.
    """
    try:
        stored = await ctx.storage.read_json(BASH_POLICY_FILE)
        policy = normalize_bash_policy(stored)
    except Exception as error:  # pragma: no cover - defensive logging path
        ctx.log.warn("failed to read bash policy; using empty policy", {"error": str(error)})
        policy = BashPolicy()
    return policy


async def remember_bash_decision(
    ctx,
    command: str,
    decision: Literal["allowed", "denied"],
    current_policy: BashPolicy,
) -> None:
    """Persist an allow/deny decision for ``command``.

    The command is added to the list named by ``decision`` and removed from
    the other list, so it never appears in both.

    Args:
        ctx: Extension context providing ``storage.write_json`` and ``log.warn``.
        command: The exact command text to remember.
        decision: ``"allowed"`` or ``"denied"``; any other value is handled
            like ``"denied"``.
        current_policy: The policy the new one is derived from; it is not
            modified.

    A failure to write is logged as a warning and otherwise ignored.
    """
    allowed = list(current_policy.allowed)
    denied = list(current_policy.denied)

    if decision == "allowed":
        allowed = unique_strings(allowed + [command])
        denied = [entry for entry in denied if entry != command]
    else:
        allowed = [entry for entry in allowed if entry != command]
        denied = unique_strings(denied + [command])

    next_policy = BashPolicy(allowed=allowed, denied=denied)

    try:
        await ctx.storage.write_json(BASH_POLICY_FILE, next_policy.model_dump())
    except Exception as error:  # pragma: no cover - defensive logging path
        failure = {"command": command, "decision": decision, "error": str(error)}
        ctx.log.warn("failed to persist bash policy decision", failure)


def normalize_bash_policy(value: Any) -> BashPolicy:
    """Turn loosely-typed stored data into a ``BashPolicy``.

    Args:
        value: Whatever was read from storage.

    Returns:
        An empty policy when ``value`` is not a dict; otherwise a policy whose
        ``allowed`` and ``denied`` lists are the cleaned-up values stored
        under those keys.
    """
    if isinstance(value, dict):
        allowed = unique_strings(value.get("allowed"))
        denied = unique_strings(value.get("denied"))
        return BashPolicy(allowed=allowed, denied=denied)
    return BashPolicy()


def unique_strings(value: Any) -> list[str]:
    """Return the distinct, stripped, non-blank strings in ``value``, sorted.

    Args:
        value: Expected to be a list; anything else yields an empty list.

    Returns:
        A sorted list of the unique stripped strings. Non-string entries and
        entries that are empty after stripping are dropped.
    """
    if not isinstance(value, list):
        return []

    distinct: set[str] = set()
    for item in value:
        if not isinstance(item, str):
            continue
        trimmed = item.strip()
        if trimmed:
            distinct.add(trimmed)
    return sorted(distinct)


def parse_json_object(value: str) -> Any:
    """Decode ``value`` as JSON, returning ``None`` when it cannot be decoded.

    Despite the name, any JSON value (object, array, number, ...) is returned
    as decoded; callers check the resulting type themselves.

    Args:
        value: The JSON text to decode.

    Returns:
        The decoded value, or ``None`` when decoding fails.
    """
    try:
        return json.loads(value)
    except ValueError:
        # JSONDecodeError is a ValueError subclass. Catching the parent also
        # covers the plain ValueError raised for integer literals longer than
        # the interpreter's int-string conversion limit, which would otherwise
        # escape this helper.
        return None


# Hand control to the extension's synchronous runner only when this file is
# executed directly (for example through the uv shebang), not when imported.
if __name__ == "__main__":
    ext.run_sync()
