#!/usr/bin/env python3
"""
pepper-stream -- real-time event stream from the pepper plane.

Connects to the control plane WebSocket and streams all events
(network_request, navigation_change, screen_appeared, watch_update,
ble_*, log, etc.) with formatted output.

Usage:
  pepper-stream                              # Stream all events
  pepper-stream --events network_request     # Only network events
  pepper-stream --events "network_request,navigation_change"
  pepper-stream --json                       # Raw JSON output (one per line)
  pepper-stream --port 8765                  # Custom port
  pepper-stream --no-color                   # Disable colors
"""

import argparse
import asyncio
import json
import os
import signal
import sys
import time
import uuid
from datetime import datetime

try:
    import websockets
except ImportError:
    print("Error: 'websockets' package required. Install with: pip install websockets", file=sys.stderr)
    sys.exit(1)

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pepper_ios.pepper_common import DEFAULT_HOST, discover_port as _discover_port


def discover_port(simulator=None):
    """Return the Pepper control-plane port, or exit the CLI on failure.

    Delegates to the shared discovery helper with a CLI fallback port of
    8765. A ``RuntimeError`` raised by the helper is printed to stderr and
    converted into exit status 1 instead of a traceback.
    """
    try:
        port = _discover_port(simulator=simulator, fallback=8765)
    except RuntimeError as err:
        print(err, file=sys.stderr)
        sys.exit(1)
    return port


# --- Colors ---

# ANSI colors default to on only when stdout is a terminal; main() can also
# switch them off via --no-color.
USE_COLOR = sys.stdout.isatty()


def _c(code, text):
    """Wrap *text* in the ANSI SGR escape *code* when colors are enabled.

    ``USE_COLOR`` is read on every call, so changing it at runtime takes
    effect immediately. With colors disabled, *text* is returned as-is (it is
    not converted to ``str``).
    """
    if not USE_COLOR:
        return text
    return f"\033[{code}m{text}\033[0m"


def green(t):
    """Return *t* wrapped in ANSI green (SGR 32)."""
    return _c("32", t)


def red(t):
    """Return *t* wrapped in ANSI red (SGR 31)."""
    return _c("31", t)


def yellow(t):
    """Return *t* wrapped in ANSI yellow (SGR 33)."""
    return _c("33", t)


def cyan(t):
    """Return *t* wrapped in ANSI cyan (SGR 36)."""
    return _c("36", t)


def magenta(t):
    """Return *t* wrapped in ANSI magenta (SGR 35)."""
    return _c("35", t)


def blue(t):
    """Return *t* wrapped in ANSI blue (SGR 34)."""
    return _c("34", t)


def dim(t):
    """Return *t* wrapped in ANSI faint/dim (SGR 2)."""
    return _c("2", t)


def bold(t):
    """Return *t* wrapped in ANSI bold (SGR 1)."""
    return _c("1", t)


def white(t):
    """Return *t* wrapped in ANSI white (SGR 37)."""
    return _c("37", t)


# --- Event formatters ---

# Color helper for each known event type, used for the "[type]" tag and the
# summary rows. Lookups elsewhere fall back to white for unknown types.
EVENT_COLORS = dict(
    network_request=cyan,
    navigation_change=yellow,
    screen_appeared=yellow,
    state_changed=blue,
    watch_update=magenta,
    log=dim,
    ble_scan_result=green,
    ble_scan_complete=green,
    ble_notification=green,
    ble_disconnect=red,
    ble_device_connected=green,
    ble_device_disconnected=red,
    ble_device_available=green,
)


def format_timestamp():
    """Return the current local time as a dimmed ``HH:MM:SS.mmm`` string."""
    now = datetime.now()
    # Millisecond precision: the first three digits of the microseconds.
    millis = now.microsecond // 1000
    return dim(f"{now:%H:%M:%S}.{millis:03d}")


def format_network(data):
    """Format a ``network_request`` event into a compact, readable line.

    Output shape: ``METHOD STATUS host/path  DURATIONms  SIZE``. The size part
    is omitted when it is unknown or not positive.

    Args:
        data: event payload. The ``request``, ``response`` and ``timing``
            sections are optional; a missing or ``null`` section is treated
            as empty. (Presumably ``response`` is null for requests that
            never completed; TODO confirm with the server schema.)

    Returns:
        The formatted line, colored when colors are enabled.
    """
    # Use `or {}` instead of a .get() default so a section that is present
    # but null is treated as empty rather than raising AttributeError.
    req = data.get("request") or {}
    resp = data.get("response") or {}
    timing = data.get("timing") or {}

    # str() guards against null/non-string values before ljust/startswith.
    method = str(req.get("method") or "?")
    url = str(req.get("url") or "?")
    status = resp.get("status_code")
    if status is None:
        status = "---"
    duration = timing.get("duration_ms", "?")
    # Prefer original_body_size; fall back to content_length when it is
    # absent or null.
    size = resp.get("original_body_size")
    if size is None:
        size = resp.get("content_length", 0)

    # Color the status: < 300 green, 3xx yellow, >= 400 red, non-numeric dim.
    if isinstance(status, int):
        if status < 300:
            status_str = green(str(status))
        elif status < 400:
            status_str = yellow(str(status))
        else:
            status_str = red(str(status))
    else:
        status_str = dim(str(status))

    # Shorten the URL by stripping the http(s):// scheme.
    short_url = url
    for prefix in ("https://", "http://"):
        if short_url.startswith(prefix):
            short_url = short_url[len(prefix):]

    # Human-readable size. Boundaries are inclusive, so exactly 1 KiB prints
    # as "1.0KB" (not "1024B") and exactly 1 MiB as "1.0MB".
    if isinstance(size, (int, float)) and size > 0:
        if size >= 1024 * 1024:
            size_str = f"{size / (1024 * 1024):.1f}MB"
        elif size >= 1024:
            size_str = f"{size / 1024:.1f}KB"
        else:
            size_str = f"{int(size)}B"
    else:
        size_str = ""

    method_str = bold(method.ljust(4))
    duration_str = dim(f"{duration}ms")
    size_part = f"  {dim(size_str)}" if size_str else ""

    return f"{method_str} {status_str} {short_url}  {duration_str}{size_part}"


def format_navigation(data):
    """Render a navigation_change event as ``<from> → <to>``.

    A missing endpoint is shown as ``?``. The origin is dimmed, and the arrow
    and destination are bold.
    """
    origin = dim(data.get("from", "?"))
    arrow = bold("→")
    destination = bold(data.get("to", "?"))
    return f"{origin} {arrow} {destination}"


def format_screen_appeared(data):
    """Render a screen_appeared event as the bold screen name (``?`` if absent)."""
    return bold(data.get("screen", "?"))


def format_watch_update(data):
    """Format a watch_update event as ``<watch_id> <change>[ → <label>]``.

    ``current`` is optional and may be ``null`` or a non-dict. Its ``label``
    is appended only when present and non-empty.
    """
    wid = data.get("watch_id", "?")
    change = data.get("change", "?")
    # `or {}` so an explicit null for "current" doesn't raise AttributeError.
    current = data.get("current") or {}
    label = current.get("label", "") if isinstance(current, dict) else ""
    suffix = f" → {label}" if label else ""
    return f"{dim(wid)} {bold(change)}{suffix}"


def format_ble(event_type, data):
    """Format a BLE (``ble_*``) event into a short, human-readable detail.

    Args:
        event_type: the event name, e.g. ``ble_scan_result``.
        data: the event payload.

    Returns:
        A one-line summary. BLE event types not handled here fall back to
        compact JSON of the payload.
    """
    if event_type == "ble_scan_result":
        name = data.get("name", "?")
        rssi = data.get("rssi", "?")
        return f"{bold(name)} RSSI={rssi}"
    if event_type == "ble_notification":
        # str() tolerates a null/non-string characteristic (slicing None
        # raises TypeError). The ellipsis is shown only when actually
        # truncated.
        char_uuid = str(data.get("characteristic") or "?")
        ellipsis = "…" if len(char_uuid) > 12 else ""
        length = data.get("length", "?")
        return f"{dim(char_uuid[:12])}{ellipsis} {length}B"
    if event_type in ("ble_disconnect", "ble_device_disconnected"):
        # Prefer the device name; fall back to the address when the name is
        # missing, null or empty.
        addr = data.get("address", "?")
        return bold(data.get("name") or addr)
    if event_type in ("ble_device_connected", "ble_device_available"):
        return bold(data.get("name", "?"))
    if event_type == "ble_scan_complete":
        count = data.get("count", "?")
        return f"{count} devices"
    return json.dumps(data, separators=(",", ":"))


def format_log(data):
    """Format a ``log`` event as ``[category] message``.

    The ``[category] `` prefix is omitted when the category is missing or
    empty. When there is no ``message`` key, ``str()`` of the whole payload
    is shown instead.
    """
    msg = data.get("message", str(data))
    cat = data.get("category", "")
    prefix = f"[{cat}] " if cat else ""
    return f"{prefix}{msg}"


def format_event(event_type, data, raw_json=False):
    """Render one event as compact NDJSON or as a colored display line.

    Args:
        event_type: event name from the stream (e.g. ``network_request``).
        data: the event payload.
        raw_json: when true, return ``{"event": ..., "data": ...}`` as
            compact JSON, with no timestamp or colors.

    Returns:
        In display mode, ``"<HH:MM:SS.mmm> [<event_type>] <detail>"``. The
        detail comes from a type-specific formatter: any ``ble_*`` type uses
        ``format_ble``, and unknown types are shown as compact JSON.
    """
    if raw_json:
        envelope = {"event": event_type, "data": data}
        return json.dumps(envelope, separators=(",", ":"))

    tag = EVENT_COLORS.get(event_type, white)(f"[{event_type}]")
    ts = format_timestamp()

    single_arg_formatters = {
        "network_request": format_network,
        "navigation_change": format_navigation,
        "screen_appeared": format_screen_appeared,
        "watch_update": format_watch_update,
        "log": format_log,
    }
    formatter = single_arg_formatters.get(event_type)
    if formatter is not None:
        detail = formatter(data)
    elif event_type.startswith("ble_"):
        detail = format_ble(event_type, data)
    else:
        detail = json.dumps(data, separators=(",", ":"))

    return f"{ts} {tag} {detail}"


# --- Counters ---

class Stats:
    """Per-event-type counters for one streaming session, plus a summary."""

    def __init__(self):
        # event type -> number of occurrences seen so far
        self.counts = {}
        # wall-clock start time, used for the elapsed time in summary()
        self.start_time = time.time()

    def record(self, event_type):
        """Count one occurrence of *event_type*."""
        try:
            self.counts[event_type] += 1
        except KeyError:
            self.counts[event_type] = 1

    def total(self):
        """Return the number of events recorded across all types."""
        return sum(self.counts.values())

    def summary(self):
        """Return a multi-line report of the session.

        The first line gives the total and the elapsed seconds. It is followed
        by one line per event type, most frequent first; ties keep first-seen
        order.
        """
        elapsed = time.time() - self.start_time
        header = f"{bold(str(self.total()))} events in {elapsed:.1f}s"
        ranked = sorted(self.counts.items(), key=lambda item: item[1], reverse=True)
        rows = [
            f"  {EVENT_COLORS.get(name, white)(name)}: {count}"
            for name, count in ranked
        ]
        return "\n".join([header] + rows)


# --- Main stream loop ---

def _parse_message(raw):
    """Decode one WebSocket frame as a JSON object; return None otherwise."""
    try:
        msg = json.loads(raw)
    except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
        return None
    return msg if isinstance(msg, dict) else None


async def _await_command_response(ws, timeout):
    """Return the first non-event JSON object received within *timeout* seconds.

    Event frames (objects with an ``"event"`` key) and non-JSON frames may
    arrive before the command reply. They are skipped so they are not
    mistaken for it. Returns None if no reply arrives in time.
    """
    loop = asyncio.get_running_loop()
    give_up_at = loop.time() + timeout
    while True:
        remaining = give_up_at - loop.time()
        if remaining <= 0:
            return None
        try:
            raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
        except asyncio.TimeoutError:
            return None
        msg = _parse_message(raw)
        if msg is not None and "event" not in msg:
            return msg


async def _subscribe(ws, event_filter, raw_json):
    """Send a ``subscribe`` command for *event_filter* and report the outcome.

    A failure (error reply, or no reply within 5 seconds) is reported on
    stderr, and the caller keeps streaming.
    """
    sub_msg = {
        "id": str(uuid.uuid4())[:8],
        "cmd": "subscribe",
        "params": {"events": event_filter},
    }
    await ws.send(json.dumps(sub_msg))
    resp = await _await_command_response(ws, timeout=5)
    if resp is None:
        print(red("Subscribe failed: no response within 5s"), file=sys.stderr)
    elif resp.get("status") == "ok":
        # `or {}`: "data" may be present but null.
        subscribed = (resp.get("data") or {}).get("subscribed", event_filter)
        if not raw_json:
            print(dim(f"Subscribed to: {', '.join(map(str, subscribed))}"))
    else:
        print(red(f"Subscribe failed: {resp}"), file=sys.stderr)


async def stream_events(host, port, event_filter, raw_json, max_events=0, duration=0):
    """Connect to the control plane at ``ws://host:port`` and print its events.

    Args:
        host: control-plane host.
        port: control-plane WebSocket port.
        event_filter: event types to subscribe to. If empty, no subscribe
            command is sent.
        raw_json: print NDJSON instead of formatted, colored lines, and
            suppress the informational banners and the final summary.
        max_events: stop after this many events (<= 0 means unlimited).
        duration: stop after this many seconds (<= 0 means unlimited).

    Exits the process with status 1 if the connection cannot be opened.
    In formatted mode, a per-type summary is printed when streaming ends.
    """
    url = f"ws://{host}:{port}"
    stats = Stats()
    loop = asyncio.get_running_loop()
    deadline = (loop.time() + duration) if duration > 0 else None
    # Distinguishes connect failures from OSErrors raised later (e.g. a
    # broken stdout pipe), which must not be reported as "Cannot connect".
    connected = False

    try:
        async with websockets.connect(
            url, ping_interval=20, ping_timeout=10, compression=None
        ) as ws:
            connected = True
            if event_filter:
                await _subscribe(ws, event_filter, raw_json)

            if not raw_json:
                filter_note = f" (filter: {', '.join(event_filter)})" if event_filter else ""
                print(dim(f"Streaming from {url}{filter_note} — Ctrl+C to stop\n"))

            while True:
                # Bound each wait by the remaining duration, so --duration
                # is honored promptly even when no events arrive.
                recv_timeout = 30
                if deadline is not None:
                    remaining = deadline - loop.time()
                    if remaining <= 0:
                        if not raw_json:
                            print(dim(f"\nDuration limit reached ({duration}s)."))
                        break
                    recv_timeout = min(recv_timeout, remaining)

                try:
                    raw = await asyncio.wait_for(ws.recv(), timeout=recv_timeout)
                except asyncio.TimeoutError:
                    # An idle period is fine; loop to re-check the deadline.
                    continue

                msg = _parse_message(raw)
                # Only process events; skip command responses and frames that
                # are not JSON objects.
                if msg is None or "event" not in msg:
                    continue

                event_type = msg["event"]
                event_data = msg.get("data", {})
                if event_data is None and not raw_json:
                    # Formatters call .get() on the payload. Raw mode keeps
                    # the null as received.
                    event_data = {}
                stats.record(event_type)

                line = format_event(event_type, event_data, raw_json)
                # Flush per line so events appear immediately when piped.
                print(line, flush=True)

                if max_events > 0 and stats.total() >= max_events:
                    if not raw_json:
                        print(dim(f"\nEvent limit reached ({max_events})."))
                    break

    except websockets.exceptions.ConnectionClosed:
        if not raw_json:
            print(red("\nConnection closed."), file=sys.stderr)
    except OSError:
        # Covers ConnectionRefusedError and also the generic OSError that
        # asyncio raises when every resolved address fails to connect.
        if connected:
            raise
        print(red(f"Cannot connect to {url}"), file=sys.stderr)
        print("Is the target app running with the control plane enabled?", file=sys.stderr)
        print("Try: make ping", file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        pass
    finally:
        if not raw_json and stats.total() > 0:
            print(f"\n{stats.summary()}")


def main():
    """Command-line entry point: parse arguments, resolve the port and stream.

    Ctrl+C ends the stream quietly. If stdout is closed early (for example
    when piped into ``head``), further output is discarded and the process
    exits with status 1 instead of printing a ``BrokenPipeError`` traceback.
    """
    # Declared up front: --no-color rebinds the module-level flag read by _c().
    global USE_COLOR

    parser = argparse.ArgumentParser(
        description="Stream real-time events from the pepper plane.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  pepper-stream                                     # All events
  pepper-stream --events network_request            # Only network traffic
  pepper-stream --events network_request,watch_update
  pepper-stream --json                              # NDJSON output
  pepper-stream --json --events network_request     # NDJSON, network only
""",
    )
    parser.add_argument(
        "--host", default=DEFAULT_HOST, help=f"Host (default: {DEFAULT_HOST})"
    )
    parser.add_argument(
        "--port", "-p", type=int, default=None,
        help="Port (default: auto-discover from /tmp/pepper-ports/)",
    )
    parser.add_argument(
        "--simulator", "-s", default=None,
        help="Simulator UDID (reads port from /tmp/pepper-ports/<UDID>.port)",
    )
    parser.add_argument(
        "--events", "-e", default="",
        help="Comma-separated event types to subscribe to (default: all)",
    )
    parser.add_argument(
        "--json", "-j", action="store_true", dest="raw_json",
        help="Output raw NDJSON (one JSON object per line)",
    )
    parser.add_argument(
        "--no-color", action="store_true",
        help="Disable color output",
    )
    parser.add_argument(
        "--limit", "-n", type=int, default=0,
        help="Stop after N events (default: unlimited)",
    )
    parser.add_argument(
        "--duration", "-d", type=int, default=0,
        help="Stop after N seconds (default: unlimited)",
    )

    args = parser.parse_args()

    if args.no_color:
        USE_COLOR = False

    # Resolve port: explicit --port wins, then --simulator lookup, then auto-discover
    if args.port is None:
        args.port = discover_port(simulator=args.simulator)

    # "" splits to [""], which the filter drops, so no events -> [].
    event_filter = [e.strip() for e in args.events.split(",") if e.strip()]

    # Run on a dedicated event loop. Ctrl+C surfaces as KeyboardInterrupt
    # and is treated as a normal stop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        loop.run_until_complete(
            stream_events(
                args.host, args.port, event_filter, args.raw_json,
                max_events=args.limit, duration=args.duration,
            )
        )
    except KeyboardInterrupt:
        pass
    except BrokenPipeError:
        # The reader of stdout went away. Point stdout at /dev/null so the
        # interpreter's final flush at exit doesn't raise again.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(1)
    finally:
        loop.close()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
