# ================================================
# ARCA — Autonomous Reinforcement Cyber Agent
# Full Codebase Context for LLM Improvement
# Generated on: $(date)
# Version: 0.1.1 (PyPI)
# GitHub: https://github.com/DipayanDasgupta/arca
# ================================================



# ==================== ROOT FILES ====================


# File: pyproject.toml
# ================================================
[project]
name = "arca-agent"
version = "0.1.1"
description = "ARCA — Local RL-powered Autonomous Cyber Pentesting Agent with LangGraph"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "MIT"}
authors = [{name = "Dipayan Dasgupta", email = "deep.dasgupta2006@gmail.com"}]
keywords = ["reinforcement-learning", "cybersecurity", "pentesting", "langgraph", "agentic-ai", "pybind11"]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "Topic :: Security",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Operating System :: POSIX :: Linux",
    "Operating System :: Microsoft :: Windows",
    "Operating System :: MacOS",
]
# Runtime requirements. Because this [project] table exists and does not list
# them as `dynamic`, setuptools ignores install_requires / extras_require passed
# to setup() in setup.py; they must be declared here or the built package
# ships with no dependencies at all.
dependencies = [
    "fastapi>=0.110",
    "gymnasium>=0.29",
    "httpx>=0.27",
    "langchain>=0.2",
    "langchain-community>=0.2",
    "langchain-core>=0.2",
    "langgraph>=0.1",
    "matplotlib>=3.8",
    "networkx>=3.0",
    "numpy>=1.24",
    "pandas>=2.0",
    "plotly>=5.20",
    "pydantic>=2.0",
    "pyyaml>=6.0",
    "rich>=13.0",
    "stable-baselines3>=2.2",
    "torch>=2.0",
    "typer>=0.12",
    "uvicorn[standard]>=0.29",
]

[project.urls]
Homepage = "https://github.com/DipayanDasgupta/arca"
Repository = "https://github.com/DipayanDasgupta/arca"

# Kept in step with extras_require in setup.py (which setuptools ignores here).
[project.optional-dependencies]
dev = ["pytest>=7.0", "pytest-cov", "black", "ruff", "mypy"]
cpp = ["pybind11>=2.11"]
viz = ["dash>=2.16", "dash-cytoscape>=1.0"]
llm = ["ollama>=0.2"]
all = ["pybind11>=2.11", "dash>=2.16", "dash-cytoscape>=1.0", "ollama>=0.2"]

[project.scripts]
arca = "arca.cli:main"

[build-system]
# pybind11 is a build-time requirement so setup.py can compile the optional
# arca._cpp_sim extension; setup.py skips the extension when pybind11 is missing.
requires = ["setuptools>=68", "wheel", "pybind11>=2.11"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
# Only the `arca` package tree is packaged; tests/ and examples/ do not match.
where = ["."]
include = ["arca*"]

[tool.setuptools.package-data]
# Non-Python data files shipped inside the installed `arca` package.
arca = ["configs/*.yaml", "data/*.json"]

[tool.black]
line-length = 100

[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "-v --tb=short"




# File: README.md
# ================================================


# ARCA — Autonomous Reinforcement Cyber Agent

> **A fully local, pip-installable RL-powered cyber pentesting simulation framework with LangGraph orchestration and optional C++ acceleration.**

[![PyPI version](https://img.shields.io/pypi/v/arca-agent.svg)](https://pypi.org/project/arca-agent/)
[![Python](https://img.shields.io/badge/python-3.10%2B-blue)](https://python.org)
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
[![RL](https://img.shields.io/badge/RL-PPO%20%7C%20A2C%20%7C%20DQN-orange)](https://stable-baselines3.readthedocs.io)
[![LangGraph](https://img.shields.io/badge/Orchestration-LangGraph-purple)](https://langchain-ai.github.io/langgraph)

---

## What is ARCA?

ARCA trains a reinforcement learning agent to autonomously discover and exploit vulnerabilities in simulated computer networks. It combines:

- **Gymnasium-compatible environment** — realistic hosts, subnets, CVEs, and network topology
- **PPO/A2C/DQN via Stable-Baselines3** — policy training with eval callbacks and checkpointing
- **LangGraph multi-agent orchestration** — Analyst → Attacker → Critic → Reflection pipeline with LLM-powered explanations
- **C++ acceleration via pybind11** — BFS reachability, batch exploit simulation, Floyd-Warshall (with pure-Python fallback)
- **FastAPI REST interface** — `/train`, `/audit`, `/reflect`, `/visualize` endpoints
- **Rich visualization suite** — Plotly network graphs, training curves, attack path overlays, vulnerability heatmaps
- **Full CLI** via Typer: `arca train`, `arca serve`, `arca audit`, `arca viz`, `arca info`

Everything runs **100% locally** — no cloud, no data leaves your machine.

---

## Installation

**Install via PyPI (Recommended)**
```bash
pip install arca-agent
```
*(Note: If your system has a C++ compiler like `g++` or `clang`, pip will automatically compile the high-performance C++ extensions during installation. Otherwise, it will gracefully fall back to the pure-Python implementation.)*

**Install from Source (For Development)**
```bash
git clone https://github.com/dipayandasgupta/arca.git
cd arca

# Create virtual environment
python -m venv venv
source venv/bin/activate       # Windows: venv\Scripts\activate

# Install in editable mode
pip install -e .

# Install with explicit C++ dependencies
pip install -e ".[cpp]"

# Install dev dependencies
pip install -e ".[dev]"
```

---

## Quickstart

```python
from arca import ARCAAgent, NetworkEnv, ARCAConfig

# Create environment
env = NetworkEnv.from_preset("small_office")

# Create and train agent
agent = ARCAAgent(env=env)
agent.train(timesteps=50_000)

# Run one episode
result = agent.run_episode(render=True)
print(result.summary())

# LangGraph reflection
agent.enable_langgraph()
report = agent.reflect(env.get_state_dict())
print(report["reflection"])
```

Or via CLI:

```bash
arca train --timesteps 50000 --preset small_office
arca serve                      # starts FastAPI at http://localhost:8000
arca audit --preset enterprise  # one-shot audit report
arca viz --output ./figures     # generate all plots
```

---

## Network Presets

| Preset        | Hosts | Subnets | Vuln Density | Max Steps |
|---------------|-------|---------|--------------|-----------|
| `small_office`  | 8     | 2       | 50%          | 150       |
| `enterprise`    | 25    | 5       | 35%          | 300       |
| `dmz`           | 15    | 3       | 45%          | 200       |
| `iot_network`   | 20    | 4       | 60%          | 250       |

---

## Actions

| Action      | Description                                          |
|-------------|------------------------------------------------------|
| `SCAN`      | Discover a reachable host and its services/vulns     |
| `EXPLOIT`   | Attempt to compromise a discovered host via a CVE    |
| `PIVOT`     | Move attacker's position to a compromised host       |
| `EXFILTRATE`| Extract data value from a compromised host           |

---

## LangGraph Architecture

```text
START → analyst_node → attacker_node → critic_node → reflect_node → END
                              ↑___________________________|
                                   (reflection loop)
```

Each node uses a local LLM (via Ollama, default: `llama3`) for natural-language analysis. Falls back to rule-based logic if Ollama is not running.

---

## C++ Acceleration

The optional `_cpp_sim` module (built via pybind11) provides:

- `compute_reachability(adj, n)` — BFS all-pairs reachability (~10x faster than NetworkX for dense graphs)
- `floyd_warshall(weights, n)` — All-pairs shortest path
- `batch_exploit(hosts, actions, seed)` — Vectorised exploit simulation

Falls back to pure Python automatically if not compiled.

---

## API Endpoints

Once you run `arca serve`:

| Endpoint              | Method | Description                        |
|-----------------------|--------|------------------------------------|
| `/`                   | GET    | Health check + status              |
| `/train`              | POST   | Start a training run               |
| `/audit`              | POST   | Run an audit episode + get report  |
| `/reflect`            | POST   | Run LangGraph reflection on state  |
| `/status`             | GET    | Current training / agent status    |
| `/docs`               | GET    | Auto-generated Swagger UI          |

---

## Project Structure

```text
arca/
├── arca/
│   ├── __init__.py          # public API
│   ├── __version__.py
│   ├── cli.py               # Typer CLI
│   ├── core/
│   │   ├── config.py        # ARCAConfig dataclass
│   │   ├── agent.py         # ARCAAgent (PPO wrapper + LangGraph)
│   │   └── trainer.py       # SB3 training harness
│   ├── sim/
│   │   ├── environment.py   # Gymnasium NetworkEnv
│   │   ├── host.py          # Host dataclass
│   │   ├── action.py        # Action / ActionResult types
│   │   └── network_generator.py
│   ├── agents/
│   │   └── langgraph_orchestrator.py
│   ├── cpp_ext/
│   │   ├── __init__.py      # Python fallback + CPP_AVAILABLE flag
│   │   └── sim_engine.cpp   # pybind11 C++ module
│   ├── viz/
│   │   └── visualizer.py    # Plotly + NetworkX charts
│   └── api/
│       └── server.py        # FastAPI app
├── tests/
│   └── test_arca.py
├── examples/
│   └── quickstart.py
├── pyproject.toml
├── setup.py
└── README.md
```

---

## Disclaimer

ARCA is a **simulation and education tool only**. All attack actions run inside a sandboxed in-memory graph. It does **not** perform any real network scanning, exploitation, or traffic generation. For authorised security testing only.

---

## Author

**Dipayan Dasgupta** — IIT Madras, Civil Engineering  
[GitHub](https://github.com/dipayandasgupta) · [LinkedIn](https://www.linkedin.com/in/dipayan-dasgupta-24a24719b/)



# File: setup.py
# ================================================
"""
setup.py — ARCA build script.

Handles C++ extension (pybind11) with graceful fallback.
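Prefer `pip install -e .`, which uses pyproject.toml. Because pybind11 is listed in
[build-system].requires, that command already attempts the C++ build.
If you build with `--no-build-isolation`, install pybind11 into the environment first
(e.g. `pip install "pybind11>=2.11"`); otherwise get_ext_modules() skips the extension.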
"""

from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
import sys
import os


class OptionalBuildExt(build_ext):
    """Build C++ extension but don't fail the whole install if it can't compile.

    Any failure is reported and the package falls back to the pure-Python
    simulation backend.
    """

    @staticmethod
    def _say(message):
        """Print ``message`` without crashing on streams that cannot encode it.

        The status messages contain non-ASCII symbols (⚠, ✓). On a stream with a
        legacy encoding (e.g. a Windows code page) a plain print() raises
        UnicodeEncodeError, which would escape the handlers below and turn a
        handled build failure into a failed install.
        """
        try:
            print(message)
        except UnicodeEncodeError:
            encoding = getattr(sys.stdout, "encoding", None) or "ascii"
            print(message.encode(encoding, errors="replace").decode(encoding))

    def run(self):
        # Guard the whole base build step, not just individual build_extension() calls.
        try:
            super().run()
        except Exception as e:
            self._say(f"\n[ARCA] ⚠ C++ extension build failed: {e}")
            self._say("[ARCA] Falling back to pure-Python simulation (all functionality still works).\n")

    def build_extension(self, ext):
        # Per-extension guard: report success or failure for this module only.
        try:
            super().build_extension(ext)
            self._say(f"[ARCA] ✓ C++ extension '{ext.name}' built successfully.")
        except Exception as e:
            self._say(f"[ARCA] ⚠ Could not build {ext.name}: {e}")
            self._say("[ARCA] Pure-Python fallback will be used.\n")


def get_ext_modules():
    """Return the optional ``arca._cpp_sim`` pybind11 extension, or ``[]``.

    The extension is skipped (with a message) when pybind11 is not importable;
    OptionalBuildExt handles compile failures at build time.
    """
    try:
        import pybind11
    except ImportError:
        print("[ARCA] pybind11 not found — skipping C++ extension.")
        return []

    if sys.platform == "win32":
        # MSVC option syntax (assumes the default MSVC toolchain on Windows).
        compile_args = ["/std:c++17", "/O2"]
    else:
        # No -march=native: it ties the binary to the build machine's CPU, so a
        # wheel built on one host can crash with an illegal instruction elsewhere.
        compile_args = ["-std=c++17", "-O3", "-fvisibility=hidden"]

    return [
        Extension(
            "arca._cpp_sim",
            sources=["arca/cpp_ext/sim_engine.cpp"],
            include_dirs=[pybind11.get_include()],
            language="c++",
            extra_compile_args=compile_args,
        )
    ]


_HERE = os.path.dirname(os.path.abspath(__file__))


def _read_version():
    """Return ``__version__`` from arca/__version__.py so setup.py never drifts from it."""
    import re

    version_file = os.path.join(_HERE, "arca", "__version__.py")
    with open(version_file, encoding="utf-8") as fh:
        match = re.search(r"""^__version__\s*=\s*["']([^"']+)["']""", fh.read(), re.MULTILINE)
    if match is None:
        raise RuntimeError(f"Could not find __version__ in {version_file}")
    return match.group(1)


# Resolve README relative to this file so builds work from any working directory.
try:
    with open(os.path.join(_HERE, "README.md"), encoding="utf-8") as f:
        long_description = f.read()
except FileNotFoundError:
    long_description = "ARCA — Autonomous Reinforcement Cyber Agent"


setup(
    name="arca-agent",
    version=_read_version(),
    author="Dipayan Dasgupta",
    author_email="ce24b059@smail.iitm.ac.in",
    description="Local RL-powered Autonomous Cyber Agent with LangGraph orchestration",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/dipayandasgupta/arca",
    packages=find_packages(exclude=["tests*", "docs*", "examples*", "scripts*"]),
    ext_modules=get_ext_modules(),
    cmdclass={"build_ext": OptionalBuildExt},
    python_requires=">=3.10",
    install_requires=[
        "numpy>=1.24",
        "gymnasium>=0.29",
        "stable-baselines3>=2.2",
        "torch>=2.0",
        "networkx>=3.0",
        "fastapi>=0.110",
        "uvicorn[standard]>=0.29",
        "pydantic>=2.0",
        "rich>=13.0",
        "typer>=0.12",
        "matplotlib>=3.8",
        "plotly>=5.20",
        "pandas>=2.0",
        "httpx>=0.27",
        "langchain>=0.2",
        "langchain-community>=0.2",
        "langgraph>=0.1",
        "langchain-core>=0.2",
        "pyyaml>=6.0",
    ],
    extras_require={
        "dev": ["pytest>=7.0", "pytest-cov", "black", "ruff", "mypy"],
        "cpp": ["pybind11>=2.11"],
        "viz": ["dash>=2.16", "dash-cytoscape>=1.0"],
        "llm": ["ollama>=0.2"],
        # Union of the cpp, viz and llm extras.
        "all": ["pybind11>=2.11", "dash>=2.16", "dash-cytoscape>=1.0", "ollama>=0.2"],
    },
    entry_points={
        "console_scripts": [
            "arca=arca.cli:main",
        ],
    },
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "Intended Audience :: Developers",
        "Topic :: Security",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Operating System :: POSIX :: Linux",
        "Operating System :: Microsoft :: Windows",
        "Operating System :: MacOS",
    ],
    keywords="reinforcement-learning cybersecurity pentesting autonomous-agent langgraph pybind11",
    include_package_data=True,
    package_data={"arca": ["configs/*.yaml", "data/*.json"]},
    zip_safe=False,
)



# File: .gitignore
# ================================================
# -----------------------------
# ARCA Specific
# -----------------------------
arca_outputs/

# -----------------------------
# Compiled Python / C++
# -----------------------------
__pycache__/
*.py[cod]
*$py.class
*.so
*.o
build/

# -----------------------------
# Packaging / Distribution
# -----------------------------
dist/
*.egg-info/
*.egg

# -----------------------------
# Environments
# -----------------------------
venv/
.venv/
env/

# -----------------------------
# OS / IDE
# -----------------------------
.DS_Store
.vscode/
.idea/




# ==================== ARCA PACKAGE ====================


# File: arca/__init__.py
# ================================================
"""
ARCA — Autonomous Reinforcement Cyber Agent
============================================
A fully local RL-powered autonomous pentesting agent with:
  • Custom network simulation environment (Gymnasium-compatible)
  • PPO-based reinforcement learning (Stable-Baselines3)
  • LangGraph multi-agent orchestration with LLM critic & reflection
  • C++ accelerated simulation via pybind11 (optional)
  • FastAPI REST interface
  • Rich visualization suite (Plotly + NetworkX)
  • Full CLI via Typer

Quickstart
----------
    from arca import ARCAAgent, NetworkEnv, ARCAConfig

    env = NetworkEnv.from_preset("small_office")
    agent = ARCAAgent(env=env)
    agent.train(timesteps=50_000)
    result = agent.run_episode()
    print(result.summary())
"""

"""ARCA — Autonomous Reinforcement Cyber Agent"""

from arca.__version__ import __version__
from arca.core.config import ARCAConfig
from arca.sim.environment import NetworkEnv
from arca.core.agent import ARCAAgent
from arca.core.trainer import ARCATrainer
from arca.viz.visualizer import ARCAVisualizer

__all__ = [
    "__version__",
    "ARCAConfig",
    "NetworkEnv",
    "ARCAAgent",
    "ARCATrainer",
    "ARCAVisualizer",
]



# File: arca/__version__.py
# ================================================
# Package version string. Keep in sync with `version` in the [project] table of
# pyproject.toml.
__version__ = "0.1.1"



# File: arca/cli.py
# ================================================
"""
arca.cli
========
Typer-based CLI for ARCA.

Commands:
  arca train   — Train a PPO (or A2C/DQN) agent on a network preset
  arca serve   — Start the FastAPI REST server
  arca audit   — Run a one-shot audit and print a report
  arca viz     — Generate topology, vulnerability-heatmap and sample training-curve plots for a preset
  arca info    — Show version and config info
"""

from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Optional

import typer
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich import print as rprint

# Root Typer application; each @app.command() below registers a subcommand.
app = typer.Typer(
    name="arca",
    no_args_is_help=True,
    add_completion=False,
    help="ARCA — Autonomous Reinforcement Cyber Agent CLI",
)
# Shared Rich console used for all CLI output.
console = Console()

# ──────────────────────────────────────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────────────────────────────────────

@app.command()
def train(
    timesteps: int = typer.Option(50_000, "--timesteps", "-t", help="Total training timesteps"),
    preset: str = typer.Option("small_office", "--preset", "-p", help="Network preset: small_office | enterprise | dmz | iot_network"),
    algo: str = typer.Option("PPO", "--algo", "-a", help="RL algorithm: PPO | A2C | DQN (case-insensitive)"),
    save_path: Optional[str] = typer.Option(None, "--save", "-s", help="Path to save trained model"),
    no_progress: bool = typer.Option(False, "--no-progress", help="Disable progress bar"),
    verbose: int = typer.Option(1, "--verbose", "-v", help="Verbosity (0=quiet, 1=normal)"),
):
    """Train a PPO (or A2C/DQN) agent on a simulated network environment.

    The trained model is saved (to --save or the configured model dir) and
    three evaluation episodes are printed.
    """
    # Normalise so "--algo ppo" works: the config and ARCAAgent.load() only know
    # the upper-case names, and an unknown name should fail before any work starts.
    algo = algo.upper()
    if algo not in ("PPO", "A2C", "DQN"):
        raise typer.BadParameter(
            f"unknown algorithm {algo!r}; choose PPO, A2C or DQN", param_hint="'--algo'"
        )

    try:
        from arca.core.config import ARCAConfig
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv
    except ImportError as e:
        console.print(f"[red]Import error: {e}[/red]")
        console.print("[yellow]Run: pip install -e .[/yellow]")
        raise typer.Exit(1)

    console.print(Panel.fit(
        f"[bold cyan]ARCA Training[/bold cyan]\n"
        f"Preset: [green]{preset}[/green]  |  Algo: [green]{algo}[/green]  |  Steps: [green]{timesteps:,}[/green]",
        border_style="cyan",
    ))

    cfg = ARCAConfig.default()
    cfg.env.preset = preset
    cfg.rl.algorithm = algo
    cfg.rl.total_timesteps = timesteps
    cfg.verbose = verbose
    cfg.ensure_dirs()

    env = NetworkEnv.from_preset(preset, cfg=cfg)
    agent = ARCAAgent(env=env, cfg=cfg)

    console.print(f"[dim]Hosts: {cfg.env.num_hosts}  Subnets: {cfg.env.num_subnets}  Obs shape: {env.observation_space.shape}[/dim]")

    agent.train(timesteps=timesteps, progress_bar=not no_progress)

    path = agent.save(save_path)
    console.print(f"\n[bold green]✓ Training complete![/bold green]  Model saved → [cyan]{path}[/cyan]")

    # Quick sanity-check of the freshly trained policy.
    console.print("\n[dim]Running 3 evaluation episodes...[/dim]")
    for i in range(3):
        episode = agent.run_episode()
        console.print(f"  Episode {i+1}: {episode.summary()}")


# ──────────────────────────────────────────────────────────────────────────────
# SERVE
# ──────────────────────────────────────────────────────────────────────────────

@app.command()
def serve(
    host: str = typer.Option("127.0.0.1", "--host", help="Bind host (use 0.0.0.0 to listen on all interfaces)"),
    port: int = typer.Option(8000, "--port", "-p", help="Port"),
    reload: bool = typer.Option(False, "--reload", help="Auto-reload on code changes"),
):
    """Start the ARCA FastAPI REST server.

    Binds to loopback by default so the API is only reachable from this
    machine. If ``arca.api.server`` cannot be imported, a minimal
    health-check server is started instead.
    """
    try:
        import uvicorn
    except ImportError:
        # "\[" keeps Rich from parsing "[standard]" as a markup tag and dropping it.
        console.print("[red]uvicorn not installed. Run: pip install uvicorn\\[standard][/red]")
        raise typer.Exit(1)

    console.print(Panel.fit(
        f"[bold cyan]ARCA API Server[/bold cyan]\n"
        f"Listening on [green]http://{host}:{port}[/green]\n"
        f"Docs: [green]http://localhost:{port}/docs[/green]",
        border_style="cyan",
    ))

    # uvicorn resolves the "module:attr" string itself and calls sys.exit(1) when
    # the module is missing instead of raising ImportError, so probe the import
    # here to keep the fallback reachable.
    try:
        import importlib
        importlib.import_module("arca.api.server")
    except ImportError:
        console.print("[yellow]API server module not found. Creating minimal server...[/yellow]")
        _run_minimal_server(host, port)
        return

    uvicorn.run(
        "arca.api.server:app",
        host=host,
        port=port,
        reload=reload,
        log_level="info",
    )


def _run_minimal_server(host: str, port: int):
    """Serve a bare-bones FastAPI app exposing only ``/`` and ``/health``.

    Used by ``serve`` when the full API module is unavailable. Any error
    (e.g. FastAPI not installed) is reported on the console instead of raised.
    """
    try:
        from fastapi import FastAPI
        import uvicorn
        # Report the real package version instead of a hard-coded string.
        from arca.__version__ import __version__

        mini_app = FastAPI(title="ARCA API", version=__version__)

        @mini_app.get("/")
        def root():
            return {"status": "ok", "message": "ARCA API running", "version": __version__}

        @mini_app.get("/health")
        def health():
            return {"status": "healthy"}

        uvicorn.run(mini_app, host=host, port=port)
    except Exception as e:
        console.print(f"[red]Could not start server: {e}[/red]")


# ──────────────────────────────────────────────────────────────────────────────
# AUDIT
# ──────────────────────────────────────────────────────────────────────────────

@app.command()
def audit(
    preset: str = typer.Option("small_office", "--preset", "-p", help="Network preset"),
    model_path: Optional[str] = typer.Option(None, "--model", "-m", help="Path to trained model (.zip)"),
    timesteps: int = typer.Option(20_000, "--timesteps", "-t", help="Quick-train timesteps if no model provided"),
    output: Optional[str] = typer.Option(None, "--output", "-o", help="Save report to JSON file"),
    langgraph: bool = typer.Option(False, "--langgraph", "-lg", help="Enable LangGraph LLM reflection"),
    algo: Optional[str] = typer.Option(None, "--algo", "-a", help="Algorithm of the model to load/train: PPO | A2C | DQN (default: config value)"),
):
    """Run a one-shot security audit and print a natural-language report.

    Loads --model if given, otherwise quick-trains for --timesteps steps, then
    runs one episode and optionally a LangGraph reflection.
    """
    if algo is not None:
        algo = algo.upper()
        if algo not in ("PPO", "A2C", "DQN"):
            raise typer.BadParameter(
                f"unknown algorithm {algo!r}; choose PPO, A2C or DQN", param_hint="'--algo'"
            )

    try:
        from arca.core.config import ARCAConfig
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv
    except ImportError as e:
        console.print(f"[red]{e}[/red]")
        raise typer.Exit(1)

    console.print(Panel.fit(
        f"[bold red]ARCA Security Audit[/bold red]\n"
        f"Target preset: [yellow]{preset}[/yellow]",
        border_style="red",
    ))

    cfg = ARCAConfig.default()
    cfg.env.preset = preset
    if algo is not None:
        # ARCAAgent.load() picks the SB3 class from cfg.rl.algorithm, so it must
        # match the algorithm the saved model was trained with.
        cfg.rl.algorithm = algo
    cfg.ensure_dirs()

    env = NetworkEnv.from_preset(preset, cfg=cfg)
    agent = ARCAAgent(env=env, cfg=cfg)

    if model_path:
        console.print(f"[dim]Loading model from {model_path}...[/dim]")
        agent.load(model_path)
    else:
        console.print(f"[dim]No model provided — quick-training for {timesteps:,} steps...[/dim]")
        agent.train(timesteps=timesteps, progress_bar=True)

    console.print("\n[bold]Running audit episode...[/bold]")
    episode_info = agent.run_episode(render=False)

    # Build report
    report = {
        "preset": preset,
        "total_reward": round(episode_info.total_reward, 2),
        "steps": episode_info.steps,
        "hosts_compromised": episode_info.hosts_compromised,
        "hosts_discovered": episode_info.hosts_discovered,
        "total_hosts": cfg.env.num_hosts,
        "goal_reached": episode_info.goal_reached,
        "attack_path": episode_info.attack_path,
        "summary": episode_info.summary(),
    }

    if langgraph:
        console.print("[dim]Running LangGraph reflection...[/dim]")
        agent.enable_langgraph()
        reflection = agent.reflect(env.get_state_dict())
        report["llm_analysis"] = reflection.get("reflection", "N/A")
        report["llm_plan"] = reflection.get("plan", "N/A")

    # Print report
    _print_audit_report(report, env)

    if output:
        out_path = Path(output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # default=str keeps the dump from failing on non-JSON values (e.g. in attack_path).
        out_path.write_text(json.dumps(report, indent=2, default=str), encoding="utf-8")
        console.print(f"\n[green]Report saved → {output}[/green]")


def _print_audit_report(report: dict, env):
    """Render an audit ``report`` dict (as built by ``audit``) to the console.

    ``env`` is accepted for interface compatibility but not used. Free-form
    values (attack-path steps, LLM output) are escaped / wrapped in ``Text`` so
    square brackets in them are shown literally instead of being parsed as Rich
    markup, which could drop text or raise a MarkupError.
    """
    from rich.markup import escape
    from rich.text import Text

    console.print()
    table = Table(title="ARCA Audit Report", border_style="red", show_header=True)
    table.add_column("Metric", style="cyan", width=30)
    table.add_column("Value", style="white")

    table.add_row("Preset", escape(str(report["preset"])))
    table.add_row("Hosts Compromised", f"{report['hosts_compromised']} / {report['total_hosts']}")
    table.add_row("Hosts Discovered", str(report["hosts_discovered"]))
    table.add_row("Goal Reached", "✅ YES" if report["goal_reached"] else "❌ NO")
    table.add_row("Steps Taken", str(report["steps"]))
    table.add_row("Total Reward", str(report["total_reward"]))

    console.print(table)

    if report.get("attack_path"):
        console.print("\n[bold red]Attack Path:[/bold red]")
        for i, step in enumerate(report["attack_path"], 1):
            console.print(f"  {i}. {escape(str(step))}")

    if report.get("llm_analysis"):
        console.print(Panel(
            Text(str(report["llm_analysis"])),
            title="[bold purple]LLM Analysis[/bold purple]",
            border_style="purple",
        ))


# ──────────────────────────────────────────────────────────────────────────────
# VIZ
# ──────────────────────────────────────────────────────────────────────────────

@app.command()
def viz(
    preset: str = typer.Option("small_office", "--preset", "-p", help="Network preset to visualize"),
    output: str = typer.Option("arca_outputs/figures", "--output", "-o", help="Output directory for HTML figures"),
    show: bool = typer.Option(False, "--show", help="Open figures in browser"),
    seed: Optional[int] = typer.Option(None, "--seed", help="Seed for the synthetic sample training curves"),
):
    """Generate network topology, vulnerability heatmap, and training curve visualizations.

    The training curves are synthetic sample data (no training log is read);
    pass --seed to make them reproducible.
    """
    try:
        from arca.sim.environment import NetworkEnv
        from arca.core.config import ARCAConfig
        from arca.viz.visualizer import ARCAVisualizer
    except ImportError as e:
        console.print(f"[red]{e}[/red]")
        raise typer.Exit(1)

    console.print(Panel.fit(
        f"[bold blue]ARCA Visualizer[/bold blue]\n"
        f"Preset: [green]{preset}[/green]  |  Output: [green]{output}[/green]",
        border_style="blue",
    ))

    cfg = ARCAConfig.default()
    # Keep the config in sync with the chosen preset, as `train` and `audit` do.
    cfg.env.preset = preset
    env = NetworkEnv.from_preset(preset, cfg=cfg)
    env.reset()

    viz_engine = ARCAVisualizer(output_dir=output)

    console.print("[dim]Generating network topology...[/dim]")
    viz_engine.plot_network(env.get_network_graph(), env.get_hosts(), save=True, show=show)

    console.print("[dim]Generating vulnerability heatmap...[/dim]")
    viz_engine.plot_vuln_heatmap(env.get_hosts(), save=True, show=show)

    console.print("[dim]Generating sample training curves...[/dim]")
    import random
    # Local generator: seedable, and does not disturb the global random state.
    rng = random.Random(seed)
    n = 50
    log_data = {
        "episodes": list(range(n)),
        "rewards": [rng.gauss(10 * (1 + i / n), 4) for i in range(n)],
        "compromised": [rng.randint(1, 6) for _ in range(n)],
        "path_lengths": [rng.randint(1, 8) for _ in range(n)],
        # Success rates are probabilities: clamp the noisy trend to [0, 1].
        "success_rates": [min(1.0, max(0.0, 0.2 + 0.5 * i / n + rng.gauss(0, 0.05))) for i in range(n)],
    }
    viz_engine.plot_training_curves(log_data, save=True, show=show)

    console.print(f"\n[bold green]✓ Figures saved to [cyan]{output}/[/cyan][/bold green]")
    console.print("  • network_topology.html")
    console.print("  • vuln_heatmap.html")
    console.print("  • training_curves.html")


# ──────────────────────────────────────────────────────────────────────────────
# INFO
# ──────────────────────────────────────────────────────────────────────────────

@app.command()
def info():
    """Show ARCA version, config defaults, and system info."""
    from arca.__version__ import __version__

    # Optional components are probed and reported, never required.
    try:
        from arca.cpp_ext import CPP_AVAILABLE as cpp_ok
    except Exception:
        cpp_ok = False

    try:
        import torch
        torch_ver, cuda = torch.__version__, torch.cuda.is_available()
    except ImportError:
        torch_ver, cuda = "not installed", False

    try:
        import stable_baselines3
        sb3_ver = stable_baselines3.__version__
    except ImportError:
        sb3_ver = "not installed"

    backend = "[green]✓ available[/green]" if cpp_ok else "[yellow]✗ pure-Python fallback[/yellow]"
    accel = "[green]✓[/green]" if cuda else "[dim]✗ CPU only[/dim]"
    lines = [
        f"[bold cyan]ARCA[/bold cyan] v{__version__} — Autonomous Reinforcement Cyber Agent",
        "",
        f"[dim]C++ backend:    [/dim]{backend}",
        f"[dim]PyTorch:        [/dim][green]{torch_ver}[/green]",
        f"[dim]CUDA:           [/dim]{accel}",
        f"[dim]SB3:            [/dim][green]{sb3_ver}[/green]",
        "",
        "[dim]GitHub: [/dim][cyan]https://github.com/dipayandasgupta/arca[/cyan]",
    ]
    console.print(Panel.fit("\n".join(lines), border_style="cyan"))


# ──────────────────────────────────────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────────────────────────────────────

def main() -> None:
    """Console-script entry point (``arca = "arca.cli:main"``): hand control to the Typer app."""
    app()


if __name__ == "__main__":
    main()



# ------------------- core -------------------


# File: arca/core/agent.py
# ================================================
"""
arca.core.agent
~~~~~~~~~~~~~~~
ARCAAgent — the main interface. Combines RL policy (SB3/PPO) with
optional LangGraph multi-agent orchestration for explainable decisions.
"""

from __future__ import annotations

import json
import time
from pathlib import Path
from typing import Optional

from arca.core.config import ARCAConfig
from arca.sim.environment import NetworkEnv, EpisodeInfo


class ARCAAgent:
    """
    High-level agent interface.

    Pairs a Stable-Baselines3 model (produced by ``ARCATrainer`` or loaded
    from disk) with a ``NetworkEnv``, plus an optional LangGraph orchestrator
    for natural-language reflection.

    Usage::

        env = NetworkEnv.from_preset("small_office")
        agent = ARCAAgent(env=env)
        agent.train(timesteps=50_000)
        result = agent.run_episode(render=True)
    """

    def __init__(
        self,
        env: Optional[NetworkEnv] = None,
        cfg: Optional[ARCAConfig] = None,
        model_path: Optional[str] = None,
    ):
        """
        Args:
            env: Environment to act in; defaults to ``NetworkEnv(cfg=self.cfg)``.
            cfg: Configuration; defaults to ``ARCAConfig()``.
            model_path: Optional saved model, loaded immediately via :meth:`load`.
        """
        self.cfg = cfg or ARCAConfig()
        self.env = env or NetworkEnv(cfg=self.cfg)
        self.cfg.ensure_dirs()

        self._model = None       # SB3 model, set by train() or load()
        self._trainer = None     # ARCATrainer from the most recent train()
        self._langgraph = None   # orchestrator, set by enable_langgraph()

        if model_path:
            self.load(model_path)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def train(
        self,
        timesteps: Optional[int] = None,
        callback=None,
        progress_bar: bool = True,
    ) -> "ARCAAgent":
        """Train a fresh model with ``ARCATrainer`` and keep it on the agent.

        Args:
            timesteps: Training budget; a falsy value (None/0) falls back to
                ``cfg.rl.total_timesteps``.
            callback: Forwarded to the trainer.
            progress_bar: Forwarded to the trainer.

        Returns:
            ``self``, for chaining.
        """
        from arca.core.trainer import ARCATrainer

        self._trainer = ARCATrainer(cfg=self.cfg, env=self.env)
        self._model = self._trainer.train(
            timesteps=timesteps or self.cfg.rl.total_timesteps,
            callback=callback,
            progress_bar=progress_bar,
        )
        return self

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def run_episode(
        self,
        render: bool = False,
        use_langgraph: bool = False,
        deterministic: bool = True,
    ) -> EpisodeInfo:
        """Roll out one full episode with the current policy.

        Args:
            render: Print ``env.render()`` after every step.
            use_langgraph: Feed each post-step state dict to the orchestrator
                (only if :meth:`enable_langgraph` has been called).
            deterministic: Forwarded to ``model.predict``.

        Returns:
            The final step's ``info["episode_info"]`` if present and truthy,
            otherwise ``env.episode_info``.

        Raises:
            RuntimeError: If no model has been trained or loaded.
        """
        if self._model is None:
            raise RuntimeError("Agent not trained. Call agent.train() first or load a model.")

        obs, info = self.env.reset()
        done = False

        while not done:
            action, _ = self._model.predict(obs, deterministic=deterministic)
            obs, _reward, terminated, truncated, info = self.env.step(int(action))
            done = terminated or truncated

            if render:
                print(self.env.render())

            if use_langgraph and self._langgraph:
                self._langgraph.step(self.env.get_state_dict())

        # Only the terminal step's info matters, so resolve it once after the loop.
        return info.get("episode_info") or self.env.episode_info

    def predict(self, obs, deterministic: bool = True):
        """Forward ``obs`` to the model's ``predict``; raises RuntimeError if no model is loaded."""
        if self._model is None:
            raise RuntimeError("No model loaded.")
        return self._model.predict(obs, deterministic=deterministic)

    # ------------------------------------------------------------------
    # LangGraph
    # ------------------------------------------------------------------

    def enable_langgraph(self) -> "ARCAAgent":
        """Create the LangGraph orchestrator from ``self.cfg``; returns ``self``."""
        from arca.agents.langgraph_orchestrator import ARCAOrchestrator
        self._langgraph = ARCAOrchestrator(cfg=self.cfg)
        return self

    def reflect(self, state: dict) -> dict:
        """Run an orchestrator reflection on ``state``, enabling LangGraph on first use."""
        if self._langgraph is None:
            self.enable_langgraph()
        return self._langgraph.reflect(state)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, path: Optional[str] = None) -> str:
        """Save the model (default ``<cfg.model_dir>/arca_model``) and return the path used.

        NOTE(review): SB3 presumably appends ".zip" when missing, so the returned
        string may not include the extension — confirm before relying on it.
        """
        if self._model is None:
            raise RuntimeError("No model to save.")
        save_path = path or str(Path(self.cfg.model_dir) / "arca_model")
        self._model.save(save_path)
        return save_path

    def load(self, path: str) -> "ARCAAgent":
        """Load a saved SB3 model whose class is chosen by ``cfg.rl.algorithm``.

        The algorithm name is matched case-insensitively; unknown names fall
        back to PPO.

        Raises:
            RuntimeError: If loading fails for any reason (original error chained).
        """
        try:
            from stable_baselines3 import PPO, A2C, DQN
            algo_map = {"PPO": PPO, "A2C": A2C, "DQN": DQN}
            AlgoCls = algo_map.get(str(self.cfg.rl.algorithm).upper(), PPO)
            self._model = AlgoCls.load(path, env=self.env)
        except Exception as e:
            # Chain the cause so the real traceback is not lost behind the wrapper.
            raise RuntimeError(f"Failed to load model from {path}: {e}") from e
        return self

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        trained = self._model is not None
        return (
            f"ARCAAgent(algo={self.cfg.rl.algorithm}, "
            f"env={self.cfg.env.preset}, "
            f"trained={trained})"
        )



# File: arca/core/config.py
# ================================================
"""
arca.core.config
~~~~~~~~~~~~~~~~
Central configuration for ARCA. All subsystems read from ARCAConfig.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional

import yaml


@dataclass
class EnvConfig:
    """Network simulation environment settings.

    Consumed by ``NetworkEnv`` and ``NetworkGenerator``; the built-in presets in
    ``arca.sim.environment.PRESETS`` are instances of this class.
    """
    # Preset name. Built-ins: "small_office" | "enterprise" | "dmz" | "iot_network".
    # NOTE(review): other values (e.g. "custom") presumably mean the fields below
    # are set by hand — confirm against the agent code.
    preset: str = "small_office"
    # Hosts are numbered 0..num_hosts-1; the action and observation spaces scale with this.
    num_hosts: int = 10
    # Hosts are assigned to subnets round-robin (host_id % num_subnets).
    num_subnets: int = 3
    vulnerability_density: float = 0.4    # per-host probability of being given 1-3 vulns
    use_cpp_backend: bool = True           # use C++ sim engine if available
    max_steps: int = 200                   # episode is truncated after this many steps
    # Reward amounts used by the environment's reward logic.
    reward_goal: float = 100.0
    reward_step: float = -0.5             # added on every step, successful or not
    reward_discovery: float = 5.0         # per newly discovered host
    reward_exploit: float = 20.0          # per newly compromised host


@dataclass
class RLConfig:
    """PPO / SB3 training settings.

    ``ARCATrainer`` always forwards policy, learning_rate, gamma, device and
    tensorboard_log to the algorithm; n_steps, batch_size, n_epochs,
    gae_lambda, clip_range and ent_coef are forwarded only for PPO.
    """
    algorithm: Literal["PPO", "A2C", "DQN"] = "PPO"
    policy: str = "MlpPolicy"
    learning_rate: float = 3e-4
    # PPO-only hyper-parameters (see class docstring).
    n_steps: int = 2048
    batch_size: int = 64
    n_epochs: int = 10
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_range: float = 0.2
    ent_coef: float = 0.01
    # NOTE(review): presumably the default training budget; ARCATrainer.train
    # also accepts its own `timesteps` argument — confirm which callers use this.
    total_timesteps: int = 100_000
    eval_freq: int = 10_000                # steps between evaluations and between checkpoints
    n_eval_episodes: int = 5
    device: str = "auto"                  # "auto" | "cpu" | "cuda" | "mps"
    tensorboard_log: Optional[str] = None  # directory is created before training when set


@dataclass
class LLMConfig:
    """LangGraph / LLM critic and reflection settings.

    NOTE(review): ARCAOrchestrator only constructs an Ollama client (via
    langchain_community); the "openai" and "anthropic" providers look
    unimplemented there — confirm before relying on them.
    """
    enabled: bool = True                  # when False, ARCAOrchestrator skips LLM setup entirely
    provider: Literal["ollama", "openai", "anthropic"] = "ollama"
    model: str = "llama3"                 # ollama model name or API model id
    base_url: str = "http://localhost:11434"
    temperature: float = 0.2
    max_tokens: int = 512
    reflection_interval: int = 10        # run reflection every N episodes
    critic_enabled: bool = True
    reflection_enabled: bool = True


@dataclass
class APIConfig:
    """FastAPI server settings.

    NOTE(review): the defaults are permissive. "0.0.0.0" is the all-interfaces
    address and a "*" CORS origin accepts any origin, assuming they are passed
    to the server unchanged. Prefer "127.0.0.1" and explicit origins outside
    local experiments.
    """
    host: str = "0.0.0.0"
    port: int = 8000
    reload: bool = False
    log_level: str = "info"
    cors_origins: list[str] = field(default_factory=lambda: ["*"])  # fresh list per instance


@dataclass
class VizConfig:
    """Visualization suite settings."""
    enabled: bool = True
    backend: Literal["plotly", "matplotlib"] = "plotly"
    dashboard_port: int = 8050
    live_update_interval: int = 2        # seconds
    save_figures: bool = True
    output_dir: str = "arca_outputs/figures"  # created by ARCAConfig.ensure_dirs()


@dataclass
class ARCAConfig:
    """Master configuration container for ARCA.

    Holds one sub-config per subsystem (env, rl, llm, api, viz) plus output
    directories, the global random seed and the verbosity level. Build it with
    ``ARCAConfig()`` / ``ARCAConfig.default()`` or load it with ``from_yaml``.
    """
    env: EnvConfig = field(default_factory=EnvConfig)
    rl: RLConfig = field(default_factory=RLConfig)
    llm: LLMConfig = field(default_factory=LLMConfig)
    api: APIConfig = field(default_factory=APIConfig)
    viz: VizConfig = field(default_factory=VizConfig)
    model_dir: str = "arca_outputs/models"
    log_dir: str = "arca_outputs/logs"
    seed: int = 42
    verbose: int = 1

    # ------------------------------------------------------------------
    # Factory helpers
    # ------------------------------------------------------------------

    @classmethod
    def default(cls) -> "ARCAConfig":
        """Return a configuration with every field at its default value."""
        return cls()

    @classmethod
    def from_yaml(cls, path: str | Path) -> "ARCAConfig":
        """Load a configuration from a YAML file, overlaying it on the defaults.

        Keys missing from the file keep their default values; nested mappings
        are merged into the matching sub-config. An empty file yields defaults.

        Raises:
            ValueError: if the top-level YAML document is not a mapping.
        """
        # Explicit UTF-8 so parsing does not depend on the platform's locale.
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Config file {path} must contain a mapping at the top level, "
                f"got {type(data).__name__}"
            )
        cfg = cls()
        _apply_dict(cfg, data)
        return cfg

    def to_yaml(self, path: str | Path) -> None:
        """Write this configuration to ``path`` as block-style YAML.

        Parent directories are created as needed.
        """
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(_to_dict(self), f, default_flow_style=False)

    def ensure_dirs(self) -> None:
        """Create ``model_dir``, ``log_dir`` and ``viz.output_dir`` if missing."""
        for d in (self.model_dir, self.log_dir, self.viz.output_dir):
            Path(d).mkdir(parents=True, exist_ok=True)


# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------

def _to_dict(obj) -> dict:
    if hasattr(obj, "__dataclass_fields__"):
        return {k: _to_dict(getattr(obj, k)) for k in obj.__dataclass_fields__}
    return obj


def _apply_dict(cfg, data: dict) -> None:
    for k, v in data.items():
        if hasattr(cfg, k):
            attr = getattr(cfg, k)
            if hasattr(attr, "__dataclass_fields__") and isinstance(v, dict):
                _apply_dict(attr, v)
            else:
                setattr(cfg, k, v)



# File: arca/core/__init__.py
# ================================================
"""arca.core — Configuration, Agent, and Trainer."""

from arca.core.config import ARCAConfig, EnvConfig, RLConfig, LLMConfig, APIConfig, VizConfig
from arca.core.agent import ARCAAgent
from arca.core.trainer import ARCATrainer

# Public names of arca.core (what `from arca.core import *` exports).
__all__ = [
    "ARCAConfig", "EnvConfig", "RLConfig", "LLMConfig", "APIConfig", "VizConfig",
    "ARCAAgent", "ARCATrainer",
]



# File: arca/core/trainer.py
# ================================================
"""
arca.core.trainer
~~~~~~~~~~~~~~~~~
ARCATrainer wraps Stable-Baselines3 training with rich logging,
eval callbacks, and TensorBoard support.
"""

from __future__ import annotations

import time
from pathlib import Path
from typing import Optional

from arca.core.config import ARCAConfig
from arca.sim.environment import NetworkEnv


class ARCATrainer:
    """Run Stable-Baselines3 training on a NetworkEnv.

    The algorithm comes from ``cfg.rl.algorithm`` ("PPO", "A2C" or "DQN";
    anything else falls back to PPO). Training is wrapped with an EvalCallback
    (best model saved to ``cfg.model_dir``, eval logs to ``cfg.log_dir``) and
    a CheckpointCallback, both firing every ``cfg.rl.eval_freq`` steps.
    """

    def __init__(self, cfg: ARCAConfig, env: Optional[NetworkEnv] = None):
        self.cfg = cfg
        self.env = env if env is not None else NetworkEnv(cfg=cfg)
        self._model = None

    def train(
        self,
        timesteps: Optional[int] = None,
        callback=None,
        progress_bar: bool = True,
    ):
        """Build a fresh model and train it.

        Args:
            timesteps: Environment steps to train for; defaults to
                ``cfg.rl.total_timesteps`` when omitted.
            callback: Optional extra SB3 callback. A ``BaseCallback`` is used
                as-is. A plain callable ``(locals, globals) -> bool`` is wrapped
                in ``ConvertCallback``, mirroring what ``learn()`` does for a
                single callback.
            progress_bar: Forwarded to ``learn()``.

        Returns:
            The trained SB3 model (also available as ``self.model``).

        Raises:
            ImportError: if stable-baselines3 is not installed.
        """
        try:
            from stable_baselines3 import PPO, A2C, DQN
            from stable_baselines3.common.callbacks import (
                BaseCallback,
                CallbackList,
                CheckpointCallback,
                ConvertCallback,
                EvalCallback,
            )
            from stable_baselines3.common.monitor import Monitor
        except ImportError as e:
            raise ImportError(f"stable-baselines3 not installed: {e}") from e

        rl = self.cfg.rl
        if timesteps is None:
            timesteps = rl.total_timesteps
        algo_map = {"PPO": PPO, "A2C": A2C, "DQN": DQN}
        AlgoCls = algo_map.get(rl.algorithm, PPO)

        monitored_env = Monitor(self.env)
        # Build the eval env from the training env's EnvConfig so both expose
        # identical observation/action spaces even when a custom env_cfg was used.
        eval_env = Monitor(NetworkEnv(cfg=self.cfg, env_cfg=getattr(self.env, "env_cfg", None)))

        try:
            tb_log = rl.tensorboard_log
            if tb_log:
                Path(tb_log).mkdir(parents=True, exist_ok=True)

            kwargs = dict(
                policy=rl.policy,
                env=monitored_env,
                learning_rate=rl.learning_rate,
                gamma=rl.gamma,
                verbose=self.cfg.verbose,
                seed=self.cfg.seed,
                device=rl.device,
                tensorboard_log=tb_log,
            )
            # Apply PPO hyper-parameters whenever PPO is actually used, including
            # the fallback for an unrecognised algorithm name.
            if AlgoCls is PPO:
                kwargs.update(
                    n_steps=rl.n_steps,
                    batch_size=rl.batch_size,
                    n_epochs=rl.n_epochs,
                    gae_lambda=rl.gae_lambda,
                    clip_range=rl.clip_range,
                    ent_coef=rl.ent_coef,
                )

            self._model = AlgoCls(**kwargs)

            callbacks = [
                EvalCallback(
                    eval_env,
                    best_model_save_path=self.cfg.model_dir,
                    log_path=self.cfg.log_dir,
                    eval_freq=rl.eval_freq,
                    n_eval_episodes=rl.n_eval_episodes,
                    deterministic=True,
                    verbose=0,
                ),
                CheckpointCallback(
                    save_freq=rl.eval_freq,
                    save_path=self.cfg.model_dir,
                    name_prefix="arca_ckpt",
                    verbose=0,
                ),
            ]
            if callback is not None:
                # CallbackList only accepts BaseCallback instances.
                if not isinstance(callback, BaseCallback):
                    callback = ConvertCallback(callback)
                callbacks.append(callback)

            start = time.perf_counter()
            self._model.learn(
                total_timesteps=timesteps,
                callback=CallbackList(callbacks),
                progress_bar=progress_bar,
            )
            elapsed = time.perf_counter() - start
        finally:
            eval_env.close()

        if self.cfg.verbose:
            print(f"\n[ARCA] Training complete in {elapsed:.1f}s ({timesteps} steps)")

        return self._model

    @property
    def model(self):
        """The most recently trained model, or ``None`` before ``train()``."""
        return self._model



# ------------------- sim -------------------


# File: arca/sim/action.py
# ================================================
"""arca.sim.action — Action and ActionResult types."""

from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any


class ActionType(IntEnum):
    """Kinds of attacker action understood by the simulation.

    The integer values matter: NetworkEnv decodes
    ``action // (num_hosts * num_exploits)`` (modulo ``len(ActionType)``) into
    these members, so renumbering or reordering them changes what a trained
    policy's actions mean.
    """
    SCAN = 0        # mark a reachable host as discovered
    EXPLOIT = 1     # try one of a discovered host's vulnerabilities
    PIVOT = 2       # move the attacker position to a compromised host
    EXFILTRATE = 3  # collect a compromised host's data_value


@dataclass
class Action:
    """A single attacker action: what to do, from which host, against which host."""

    action_type: ActionType
    target_host: int
    exploit_id: int
    source_host: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a plain dict; the action type is rendered by its enum name."""
        return dict(
            type=self.action_type.name,
            source=self.source_host,
            target=self.target_host,
            exploit_id=self.exploit_id,
        )

    def __str__(self) -> str:
        kind = self.action_type.name
        route = f"{self.source_host}→{self.target_host}"
        return f"{kind}({route}, exploit={self.exploit_id})"


@dataclass
class ActionResult:
    """Outcome of executing an Action in the simulation."""

    success: bool
    message: str = ""
    discovered_hosts: list[int] = field(default_factory=list)
    compromised_host: int | None = None
    data_exfiltrated: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a plain dict; ``data_exfiltrated`` is rounded to 2 decimals.

        The ``discovered_hosts`` list is returned as-is (not copied).
        """
        payload: dict[str, Any] = {}
        payload["success"] = self.success
        payload["message"] = self.message
        payload["discovered_hosts"] = self.discovered_hosts
        payload["compromised_host"] = self.compromised_host
        payload["data_exfiltrated"] = round(self.data_exfiltrated, 2)
        return payload



# File: arca/sim/environment.py
# ================================================
"""
arca.sim.environment
~~~~~~~~~~~~~~~~~~~~
Gymnasium-compatible network penetration testing environment.

Nodes  = hosts (id, subnet, os, vulnerabilities, compromised)
Edges  = reachability between hosts
Actions = (action_type, target_host, exploit_id)
Observation = flat feature vector of known network state
"""

from __future__ import annotations

import random
from dataclasses import dataclass, field
from typing import Any, Optional

import gymnasium as gym
import networkx as nx
import numpy as np
from gymnasium import spaces

from arca.core.config import ARCAConfig, EnvConfig
from arca.sim.network_generator import NetworkGenerator
from arca.sim.host import Host, HostStatus
from arca.sim.action import Action, ActionType, ActionResult

# Try to import C++ backend
try:
    from arca._cpp_sim import SimEngine as _CppSimEngine  # type: ignore
    _CPP_AVAILABLE = True
except ImportError:
    _CppSimEngine = None
    _CPP_AVAILABLE = False


# Built-in topologies for NetworkEnv.from_preset. Each entry records its own
# name in ``preset`` so that ``cfg.env.preset`` reports the preset in use
# rather than the EnvConfig default.
PRESETS = {
    "small_office": EnvConfig(
        preset="small_office", num_hosts=8, num_subnets=2, vulnerability_density=0.5, max_steps=150
    ),
    "enterprise": EnvConfig(
        preset="enterprise", num_hosts=25, num_subnets=5, vulnerability_density=0.35, max_steps=300
    ),
    "dmz": EnvConfig(
        preset="dmz", num_hosts=15, num_subnets=3, vulnerability_density=0.45, max_steps=200
    ),
    "iot_network": EnvConfig(
        preset="iot_network", num_hosts=20, num_subnets=4, vulnerability_density=0.6, max_steps=250
    ),
}


@dataclass
class EpisodeInfo:
    """Running statistics for one episode of NetworkEnv."""

    total_reward: float = 0.0
    steps: int = 0
    hosts_compromised: int = 0
    hosts_discovered: int = 0
    goal_reached: bool = False
    attack_path: list[str] = field(default_factory=list)
    action_log: list[dict] = field(default_factory=list)

    def summary(self) -> str:
        """One-line human-readable summary (reward shown with 2 decimals)."""
        goal_mark = "✓" if self.goal_reached else "✗"
        parts = [
            f"reward={self.total_reward:.2f}",
            f"steps={self.steps}",
            f"compromised={self.hosts_compromised}/{self.hosts_discovered} discovered",
            f"goal={goal_mark}",
        ]
        return "EpisodeInfo(" + ", ".join(parts) + ")"


class NetworkEnv(gym.Env):
    """
    Cyber pentesting simulation environment.

    Observation space: flat float32 vector of ``num_hosts * _HOST_FEATURES``
    values in [0, 1] (see ``_get_obs``).
    Action space: Discrete — (action_type × num_hosts × num_exploits), decoded
    in ``step`` as ``action_type * (n * e) + host * e + exploit``.

    Rewards: ``reward_step`` every step, plus ``reward_discovery`` per newly
    discovered host, ``reward_exploit`` per compromise, ``2 * data_value`` per
    exfiltration (once per host per episode), and ``reward_goal`` on the step
    the goal is first reached.
    """

    metadata = {"render_modes": ["human", "rgb_array", "ansi"]}

    # Feature dims per host: [discovered, compromised, os_enc(4), subnet_id, vuln_count, service_count]
    _HOST_FEATURES = 9
    _NUM_EXPLOITS = 5

    def __init__(self, cfg: Optional[ARCAConfig] = None, env_cfg: Optional[EnvConfig] = None):
        super().__init__()
        self.cfg = cfg or ARCAConfig()
        self.env_cfg = env_cfg or self.cfg.env
        self._rng = random.Random(self.cfg.seed if cfg else 42)
        self._np_rng = np.random.default_rng(self.cfg.seed if cfg else 42)

        # Use C++ backend if available and configured.
        # NOTE(review): this flag is not read anywhere in this class; the Python
        # mechanics below always run — confirm whether other code consults it.
        self._use_cpp = self.env_cfg.use_cpp_backend and _CPP_AVAILABLE

        # The generator keeps a reference to self._rng, so reset() reseeds that
        # same object (rather than replacing it) to make seeded topologies reproducible.
        self._generator = NetworkGenerator(self.env_cfg, self._rng)

        # Will be populated in reset()
        self.graph: nx.DiGraph = nx.DiGraph()
        self.hosts: dict[int, Host] = {}
        self._attacker_node: int = 0
        self._episode_info: EpisodeInfo = EpisodeInfo()
        self._step_count: int = 0
        # Hosts already exfiltrated this episode; each host pays out only once.
        self._exfiltrated: set[int] = set()

        n = self.env_cfg.num_hosts
        e = self._NUM_EXPLOITS
        num_action_types = len(ActionType)

        # Action: flat index = action_type * (n * e) + host * e + exploit
        self.action_space = spaces.Discrete(num_action_types * n * e)
        self.observation_space = spaces.Box(
            low=0.0,
            high=1.0,
            shape=(n * self._HOST_FEATURES,),
            dtype=np.float32,
        )

    # ------------------------------------------------------------------
    # Factory
    # ------------------------------------------------------------------

    @classmethod
    def from_preset(cls, preset: str, cfg: Optional[ARCAConfig] = None) -> "NetworkEnv":
        """Create an env using one of the built-in ``PRESETS``.

        ``cfg.env`` (of the given or a fresh ARCAConfig) is replaced by a *copy*
        of the preset, with ``preset`` set to its name. Later edits to
        ``cfg.env`` therefore never mutate the shared ``PRESETS`` entry.

        Raises:
            ValueError: if ``preset`` is not a known preset name.
        """
        from dataclasses import replace

        if preset not in PRESETS:
            raise ValueError(f"Unknown preset '{preset}'. Available: {list(PRESETS)}")
        base_cfg = cfg or ARCAConfig()
        base_cfg.env = replace(PRESETS[preset], preset=preset)
        return cls(cfg=base_cfg)

    @classmethod
    def from_config(cls, cfg: ARCAConfig) -> "NetworkEnv":
        return cls(cfg=cfg)

    # ------------------------------------------------------------------
    # Gym interface
    # ------------------------------------------------------------------

    def reset(self, *, seed=None, options=None):
        """Generate a new network and place the attacker on its start host.

        When ``seed`` is given, both the topology RNG (shared with the
        generator) and the exploit-roll RNG are reseeded, so the same seed
        reproduces the same episode setup.
        """
        super().reset(seed=seed)
        if seed is not None:
            # Reseed in place: NetworkGenerator holds a reference to self._rng.
            self._rng.seed(seed)
            self._np_rng = np.random.default_rng(seed)

        self.graph, self.hosts = self._generator.generate()
        self._attacker_node = self._generator.attacker_node
        self._step_count = 0
        self._episode_info = EpisodeInfo()
        self._exfiltrated = set()

        # Mark attacker start as discovered & compromised
        self.hosts[self._attacker_node].status = HostStatus.COMPROMISED
        self.hosts[self._attacker_node].discovered = True
        self._episode_info.hosts_compromised = 1
        self._episode_info.hosts_discovered = 1

        obs = self._get_obs()
        info = {"attacker_node": self._attacker_node, "num_hosts": len(self.hosts)}
        return obs, info

    def step(self, action: int):
        """Decode and execute one flat action index; returns the Gymnasium 5-tuple."""
        # SB3 may pass a numpy integer / 0-d array; a plain int keeps the action
        # log (and anything serialising it) free of numpy scalars.
        action = int(action)
        n = self.env_cfg.num_hosts
        e = self._NUM_EXPLOITS
        action_type_idx, remainder = divmod(action, n * e)
        target_host_idx, exploit_idx = divmod(remainder, e)

        action_type = ActionType(action_type_idx % len(ActionType))
        target_host = target_host_idx % n

        act = Action(
            action_type=action_type,
            target_host=target_host,
            exploit_id=exploit_idx,
            source_host=self._attacker_node,
        )

        result = self._execute_action(act)
        reward = self._compute_reward(result)
        self._step_count += 1

        terminated = self._check_goal()
        truncated = self._step_count >= self.env_cfg.max_steps
        # Goal bonus is paid once, on the step the goal is first reached.
        if terminated and not self._episode_info.goal_reached:
            reward += self.env_cfg.reward_goal
        self._episode_info.goal_reached = terminated

        self._episode_info.total_reward += reward
        self._episode_info.steps = self._step_count
        self._episode_info.action_log.append({
            "step": self._step_count,
            "action": act.to_dict(),
            "result": result.to_dict(),
            "reward": reward,
        })

        obs = self._get_obs()
        info = {
            "action_result": result.to_dict(),
            "episode_info": self._episode_info if (terminated or truncated) else None,
        }
        return obs, reward, terminated, truncated, info

    def render(self, mode="ansi"):
        """Return a text table of host states plus the step/reward line."""
        lines = ["=" * 50, "ARCA Network State", "=" * 50]
        for hid, host in self.hosts.items():
            lines.append(
                f"  Host {hid:02d} [{host.subnet}] {host.os:<10} "
                f"{'🔴 PWNED' if host.status == HostStatus.COMPROMISED else '🟡 SEEN' if host.discovered else '⬛ UNKNOWN'}"
                f"  vulns={len(host.vulnerabilities)}"
            )
        lines.append(f"Step {self._step_count}/{self.env_cfg.max_steps}  "
                     f"Reward={self._episode_info.total_reward:.1f}")
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Internal mechanics
    # ------------------------------------------------------------------

    def _execute_action(self, act: Action) -> ActionResult:
        target = self.hosts.get(act.target_host)
        if target is None:
            return ActionResult(success=False, message="Invalid target")

        if act.action_type == ActionType.SCAN:
            return self._do_scan(act, target)
        elif act.action_type == ActionType.EXPLOIT:
            return self._do_exploit(act, target)
        elif act.action_type == ActionType.PIVOT:
            return self._do_pivot(act, target)
        elif act.action_type == ActionType.EXFILTRATE:
            return self._do_exfiltrate(act, target)
        else:
            return ActionResult(success=False, message="Unknown action")

    def _do_scan(self, act: Action, target: Host) -> ActionResult:
        # Check reachability
        if not self._is_reachable(act.source_host, act.target_host):
            return ActionResult(success=False, message="Host unreachable")
        was_new = not target.discovered
        target.discovered = True
        if was_new:
            self._episode_info.hosts_discovered += 1
        return ActionResult(
            success=True,
            discovered_hosts=[act.target_host] if was_new else [],
            message=f"Scanned host {act.target_host}: {target.os}, {len(target.vulnerabilities)} vulns",
        )

    def _do_exploit(self, act: Action, target: Host) -> ActionResult:
        if not target.discovered:
            return ActionResult(success=False, message="Host not discovered yet")
        if not self._is_reachable(act.source_host, act.target_host):
            return ActionResult(success=False, message="Host unreachable")
        if target.status == HostStatus.COMPROMISED:
            return ActionResult(success=False, message="Already compromised")
        # exploit_id indexes the target's own vulnerability list.
        if act.exploit_id < len(target.vulnerabilities):
            vuln = target.vulnerabilities[act.exploit_id]
            success_prob = vuln.get("exploit_prob", 0.6)
            if self._np_rng.random() < success_prob:
                target.status = HostStatus.COMPROMISED
                self._attacker_node = act.target_host  # pivot point
                self._episode_info.hosts_compromised += 1
                self._episode_info.attack_path.append(
                    f"{act.source_host}→{act.target_host}(CVE:{vuln.get('cve','?')})"
                )
                return ActionResult(
                    success=True,
                    compromised_host=act.target_host,
                    message=f"Exploited {target.os} via {vuln.get('name','unknown')}",
                )
        return ActionResult(success=False, message="Exploit failed or no matching vuln")

    def _do_pivot(self, act: Action, target: Host) -> ActionResult:
        if target.status != HostStatus.COMPROMISED:
            return ActionResult(success=False, message="Cannot pivot to non-compromised host")
        self._attacker_node = act.target_host
        return ActionResult(success=True, message=f"Pivoted to host {act.target_host}")

    def _do_exfiltrate(self, act: Action, target: Host) -> ActionResult:
        if target.status != HostStatus.COMPROMISED:
            return ActionResult(success=False, message="Cannot exfiltrate from non-compromised host")
        # Without this guard a policy can farm reward by exfiltrating the same
        # (e.g. starting) host on every step.
        if act.target_host in self._exfiltrated:
            return ActionResult(
                success=False,
                message=f"Data already exfiltrated from host {act.target_host}",
            )
        self._exfiltrated.add(act.target_host)
        data_value = target.data_value
        return ActionResult(success=True, data_exfiltrated=data_value,
                            message=f"Exfiltrated {data_value:.1f} units from host {act.target_host}")

    def _is_reachable(self, src: int, dst: int) -> bool:
        """True if a directed path exists from ``src`` to ``dst`` (or they are equal)."""
        if src == dst:
            return True
        try:
            return nx.has_path(self.graph, src, dst)
        except (nx.NetworkXError, nx.NodeNotFound):
            # NodeNotFound is not a NetworkXError subclass, so list it explicitly.
            return False

    def _compute_reward(self, result: ActionResult) -> float:
        if not result.success:
            return self.env_cfg.reward_step
        r = 0.0
        if result.discovered_hosts:
            r += self.env_cfg.reward_discovery * len(result.discovered_hosts)
        if result.compromised_host is not None:
            r += self.env_cfg.reward_exploit
        if result.data_exfiltrated > 0:
            r += result.data_exfiltrated * 2.0
        r += self.env_cfg.reward_step
        return r

    def _check_goal(self) -> bool:
        """Goal: compromise max(3, half the hosts), capped at the host count.

        The cap keeps the goal reachable on networks with fewer than 3 hosts.
        """
        if not self.hosts:
            return False
        n_compromised = sum(1 for h in self.hosts.values() if h.status == HostStatus.COMPROMISED)
        required = min(len(self.hosts), max(3, len(self.hosts) // 2))
        return n_compromised >= required

    def _get_obs(self) -> np.ndarray:
        # NOTE(review): os, subnet, vuln-count and service-count features are
        # filled in even for hosts not yet discovered, so the agent sees more than
        # the "known network state" the module docstring describes — confirm
        # whether these should be masked until discovery.
        n = self.env_cfg.num_hosts
        obs = np.zeros(n * self._HOST_FEATURES, dtype=np.float32)
        os_map = {"Windows": 0, "Linux": 1, "macOS": 2, "IoT": 3}
        for i in range(n):
            host = self.hosts.get(i)
            base = i * self._HOST_FEATURES
            if host is None:
                continue
            obs[base + 0] = float(host.discovered)
            obs[base + 1] = float(host.status == HostStatus.COMPROMISED)
            os_idx = os_map.get(host.os, 0)
            obs[base + 2 + os_idx] = 1.0
            obs[base + 6] = host.subnet / max(self.env_cfg.num_subnets, 1)
            obs[base + 7] = len(host.vulnerabilities) / 10.0
            obs[base + 8] = len(host.services) / 10.0
        return obs

    # ------------------------------------------------------------------
    # Utility / introspection
    # ------------------------------------------------------------------

    @property
    def episode_info(self) -> EpisodeInfo:
        return self._episode_info

    def get_network_graph(self) -> nx.DiGraph:
        return self.graph

    def get_hosts(self) -> dict[int, Host]:
        return self.hosts

    def get_state_dict(self) -> dict:
        """JSON-friendly snapshot of the current step, attacker position and hosts."""
        return {
            "step": self._step_count,
            "attacker_node": self._attacker_node,
            "hosts": {hid: h.to_dict() for hid, h in self.hosts.items()},
            "episode_info": {
                "total_reward": self._episode_info.total_reward,
                "hosts_compromised": self._episode_info.hosts_compromised,
                "hosts_discovered": self._episode_info.hosts_discovered,
                "attack_path": self._episode_info.attack_path,
            },
        }



# File: arca/sim/host.py
# ================================================
"""arca.sim.host — Host node in the network simulation."""

from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any


class HostStatus(IntEnum):
    """Compromise state of a host.

    NOTE(review): arca.sim.environment only compares against COMPROMISED;
    discovery is tracked separately by ``Host.discovered``, and DISCOVERED
    appears unused there — confirm before relying on it.
    """
    UNKNOWN = 0
    DISCOVERED = 1
    COMPROMISED = 2


@dataclass
class Host:
    """A single machine in the simulated network (one graph node)."""

    id: int
    subnet: int
    os: str
    ip: str
    services: list[str] = field(default_factory=list)
    vulnerabilities: list[dict] = field(default_factory=list)
    status: HostStatus = HostStatus.UNKNOWN
    discovered: bool = False
    data_value: float = 0.0
    is_critical: bool = False
    firewall: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return a serialisable summary of this host.

        Vulnerabilities are reduced to their names ("?" when a name is missing),
        ``status`` is rendered by enum name, and ``data_value`` is rounded to 2
        decimals. The ``firewall`` flag is not included.
        """
        vuln_names = [vuln.get("name", "?") for vuln in self.vulnerabilities]
        summary: dict[str, Any] = dict(
            id=self.id,
            subnet=self.subnet,
            os=self.os,
            ip=self.ip,
            services=self.services,
            vulnerabilities=vuln_names,
            status=self.status.name,
            discovered=self.discovered,
            data_value=round(self.data_value, 2),
            is_critical=self.is_critical,
        )
        return summary



# File: arca/sim/__init__.py
# ================================================
from arca.sim.environment import NetworkEnv, EpisodeInfo
from arca.sim.host import Host, HostStatus
from arca.sim.action import Action, ActionType, ActionResult

# Public names of arca.sim (what `from arca.sim import *` exports).
__all__ = ["NetworkEnv", "EpisodeInfo", "Host", "HostStatus", "Action", "ActionType", "ActionResult"]



# File: arca/sim/network_generator.py
# ================================================
"""arca.sim.network_generator — Procedural network topology generator."""

from __future__ import annotations

import random
from typing import Tuple

import networkx as nx

from arca.sim.host import Host, HostStatus
from arca.core.config import EnvConfig

# Vulnerability DB (simplified)
# Each entry has a display "name", a "cve" id, an "exploit_prob" (the success
# probability NetworkEnv rolls against when the entry is exploited) and the
# target "os" NetworkGenerator uses to pick candidate vulnerabilities per host.
VULN_DB = [
    {"name": "EternalBlue", "cve": "CVE-2017-0144", "exploit_prob": 0.75, "os": "Windows"},
    {"name": "Log4Shell", "cve": "CVE-2021-44228", "exploit_prob": 0.8, "os": "Linux"},
    {"name": "ProxyLogon", "cve": "CVE-2021-26855", "exploit_prob": 0.7, "os": "Windows"},
    {"name": "Shellshock", "cve": "CVE-2014-6271", "exploit_prob": 0.65, "os": "Linux"},
    {"name": "Heartbleed", "cve": "CVE-2014-0160", "exploit_prob": 0.6, "os": "Linux"},
    {"name": "BlueKeep", "cve": "CVE-2019-0708", "exploit_prob": 0.7, "os": "Windows"},
    {"name": "PrintNightmare", "cve": "CVE-2021-34527", "exploit_prob": 0.65, "os": "Windows"},
    {"name": "Dirty COW", "cve": "CVE-2016-5195", "exploit_prob": 0.55, "os": "Linux"},
    {"name": "IoT Default Creds", "cve": "CVE-2020-8958", "exploit_prob": 0.9, "os": "IoT"},
    {"name": "macOS TCC Bypass", "cve": "CVE-2023-41990", "exploit_prob": 0.5, "os": "macOS"},
]

# OS choices per subnet index; NetworkGenerator looks up `subnet % 5`, so
# networks with more than 5 subnets cycle through these profiles.
OS_BY_SUBNET = {
    0: ["Windows", "Linux"],  # DMZ
    1: ["Windows", "Linux", "macOS"],  # Corp
    2: ["Linux", "IoT"],  # OT/IoT
    3: ["Windows"],  # AD
    4: ["Linux"],  # Servers
}

# Per-OS pool from which each host draws 1-3 services.
SERVICES = {
    "Windows": ["SMB", "RDP", "WinRM", "IIS", "MSSQL"],
    "Linux": ["SSH", "HTTP", "HTTPS", "FTP", "NFS", "PostgreSQL"],
    "macOS": ["SSH", "HTTP", "AFP"],
    "IoT": ["Telnet", "HTTP", "MQTT"],
}


class NetworkGenerator:
    """Procedural network topology generator driven by an EnvConfig.

    All randomness is drawn from the injected ``rng`` in a fixed order, so the
    same RNG state yields the same topology.
    """

    def __init__(self, cfg: EnvConfig, rng: random.Random):
        self.cfg = cfg
        self.rng = rng
        # Host id where the attacker starts; set by generate().
        self.attacker_node: int = 0

    def generate(self) -> Tuple[nx.DiGraph, dict[int, Host]]:
        """Build a fresh graph and host table.

        Hosts ``0..num_hosts-1`` are assigned round-robin to subnets; OS,
        services, vulnerabilities, IP and data value are drawn from ``rng``.
        Hosts in the same subnet are fully meshed, and each ordered pair of
        hosts in adjacent subnets gets an edge with probability 0.3.
        ``attacker_node`` is set to a random subnet-0 host (0 if there is none).

        Returns:
            ``(graph, hosts)``: graph nodes are host ids carrying
            ``Host.to_dict()`` attributes; ``hosts`` maps id -> Host.
        """
        n = self.cfg.num_hosts
        # Treat a non-positive subnet count as a single subnet instead of failing
        # with ZeroDivisionError in the round-robin assignment below.
        s = max(1, self.cfg.num_subnets)
        g = nx.DiGraph()
        hosts: dict[int, Host] = {}

        # Create hosts (the order of rng calls below determines the topology for a seed)
        for i in range(n):
            subnet = i % s
            os_choices = OS_BY_SUBNET.get(subnet % 5, ["Linux"])
            os = self.rng.choice(os_choices)
            svc_pool = SERVICES.get(os, ["SSH"])
            services = self.rng.sample(svc_pool, k=min(self.rng.randint(1, 3), len(svc_pool)))

            # Assign vulnerabilities
            vulns = []
            if self.rng.random() < self.cfg.vulnerability_density:
                # NOTE(review): Linux entries are offered to every OS (e.g. Dirty COW on
                # Windows hosts) — confirm this cross-platform pool is intended.
                os_vulns = [v for v in VULN_DB if v["os"] == os or v["os"] == "Linux"]
                n_vulns = self.rng.randint(1, min(3, len(os_vulns)))
                # Copy the entries so per-host edits never mutate the shared VULN_DB.
                vulns = [dict(v) for v in self.rng.sample(os_vulns, k=n_vulns)]

            ip = f"10.{subnet}.{self.rng.randint(1, 254)}.{i+1}"
            hosts[i] = Host(
                id=i,
                subnet=subnet,
                os=os,
                ip=ip,
                services=services,
                vulnerabilities=vulns,
                data_value=self.rng.uniform(1.0, 15.0),
                is_critical=(i == n - 1),  # last host is the crown jewel
                firewall=(subnet == 0),
            )
            g.add_node(i, **hosts[i].to_dict())

        # Create edges (reachability)
        # Within-subnet: full mesh
        for a in range(n):
            for b in range(n):
                if a != b and hosts[a].subnet == hosts[b].subnet:
                    g.add_edge(a, b)

        # Cross-subnet: limited edges (gateway links) between adjacent subnets only
        for i in range(n):
            for j in range(n):
                if hosts[i].subnet != hosts[j].subnet:
                    if abs(hosts[i].subnet - hosts[j].subnet) == 1:
                        if self.rng.random() < 0.3:
                            g.add_edge(i, j)

        # Attacker starts at a random host in subnet 0 (internet-facing)
        subnet0_hosts = [i for i, h in hosts.items() if h.subnet == 0]
        self.attacker_node = self.rng.choice(subnet0_hosts) if subnet0_hosts else 0

        return g, hosts



# ------------------- agents -------------------


# File: arca/agents/__init__.py
# ================================================
from arca.agents.langgraph_orchestrator import ARCAOrchestrator, ARCAGraphState

# Public names of arca.agents (what `from arca.agents import *` exports).
__all__ = ["ARCAOrchestrator", "ARCAGraphState"]



# File: arca/agents/langgraph_orchestrator.py
# ================================================
"""
arca.agents.langgraph_orchestrator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
LangGraph-based multi-agent system for explainable RL decision-making.

Graph topology:
  ┌──────────────┐
  │  State Input │
  └──────┬───────┘
         ▼
  ┌──────────────┐    ┌────────────────┐
  │  Analyst     │───▶│  Critic        │
  │  (describes  │    │  (evaluates    │
  │   state)     │    │   RL action)   │
  └──────────────┘    └───────┬────────┘
                              ▼
                     ┌────────────────┐
                     │  Reflector     │
                     │  (learns from  │
                     │   mistakes)    │
                     └───────┬────────┘
                             ▼
                     ┌────────────────┐
                     │  Planner       │
                     │  (suggests     │
                     │   next steps)  │
                     └────────────────┘
"""

from __future__ import annotations

import json
from typing import Any, TypedDict, Optional

from arca.core.config import ARCAConfig


class ARCAGraphState(TypedDict):
    """State dict passed through the analyst -> critic -> reflector -> planner graph.

    NOTE(review): each node presumably reads earlier fields and fills in its
    own output field — confirm against the node implementations.
    """
    network_state: dict          # presumably a NetworkEnv.get_state_dict() snapshot — TODO confirm
    analyst_output: str          # presumably written by the analyst node
    critic_output: str           # presumably written by the critic node
    reflection: str              # presumably written by the reflector node
    plan: str                    # presumably written by the planner node
    episode_history: list[dict]  # schema not visible here — TODO confirm


class ARCAOrchestrator:
    """
    LangGraph-powered orchestrator with LLM nodes.
    Falls back to rule-based heuristics if LLM is unavailable.

    :meth:`step` runs four nodes in order -- analyst, critic, reflector,
    planner -- each filling one field of ``ARCAGraphState``. The nodes run
    inside a compiled LangGraph ``StateGraph`` only when the LLM answered the
    connectivity ping and the graph compiled; otherwise they are called
    sequentially in-process. Any node whose LLM call fails or returns an empty
    answer uses a deterministic rule-based answer instead.
    """

    def __init__(self, cfg: Optional[ARCAConfig] = None):
        self.cfg = cfg or ARCAConfig()
        self._llm = None
        self._graph = None
        # One {"step", "reflection", "plan"} entry is appended per step() call.
        self._memory: list[dict] = []
        self._llm_available = False

        if self.cfg.llm.enabled:
            self._init_llm()
            # A compiled graph is only built when the LLM is reachable.
            if self._llm_available:
                self._build_graph()

    def _init_llm(self) -> None:
        """Create the Ollama client and check that it answers; never raises.

        On any failure (missing ``langchain_community``, server down, ...)
        ``_llm_available`` stays False and the rule-based fallback is used.
        """
        try:
            from langchain_community.llms import Ollama  # type: ignore
            self._llm = Ollama(
                model=self.cfg.llm.model,
                base_url=self.cfg.llm.base_url,
                temperature=self.cfg.llm.temperature,
            )
            # Quick connectivity test
            self._llm.invoke("ping")
            self._llm_available = True
            print(f"[ARCA] LLM connected: {self.cfg.llm.model} @ {self.cfg.llm.base_url}")
        except Exception as e:
            print(f"[ARCA] LLM unavailable ({e}). Using rule-based fallback.")
            self._llm_available = False

    def _build_graph(self) -> None:
        """Compile the linear analyst -> critic -> reflector -> planner graph.

        Leaves ``_graph`` as None (sequential fallback in :meth:`step`) if
        LangGraph cannot be imported or compilation fails.
        """
        try:
            from langgraph.graph import StateGraph, END  # type: ignore

            graph = StateGraph(ARCAGraphState)
            graph.add_node("analyst", self._analyst_node)
            graph.add_node("critic", self._critic_node)
            graph.add_node("reflector", self._reflector_node)
            graph.add_node("planner", self._planner_node)

            graph.set_entry_point("analyst")
            graph.add_edge("analyst", "critic")
            graph.add_edge("critic", "reflector")
            graph.add_edge("reflector", "planner")
            graph.add_edge("planner", END)

            self._graph = graph.compile()
        except Exception as e:
            print(f"[ARCA] LangGraph build failed: {e}. Using sequential fallback.")
            self._graph = None

    # ------------------------------------------------------------------
    # Network-state accessors (tolerate missing or None sections)
    # ------------------------------------------------------------------

    @staticmethod
    def _hosts(ns: dict) -> dict:
        """``ns["hosts"]`` (host id -> host dict), or ``{}`` if missing/None."""
        return ns.get("hosts") or {}

    @staticmethod
    def _episode_info(ns: dict) -> dict:
        """``ns["episode_info"]``, or ``{}`` if missing/None."""
        return ns.get("episode_info") or {}

    @staticmethod
    def _count_compromised(hosts: dict) -> int:
        """Number of host dicts whose ``status`` equals ``"COMPROMISED"``."""
        return sum(1 for h in hosts.values() if h.get("status") == "COMPROMISED")

    @staticmethod
    def _count_discovered(hosts: dict) -> int:
        """Number of host dicts with a truthy ``discovered`` flag."""
        return sum(1 for h in hosts.values() if h.get("discovered"))

    # ------------------------------------------------------------------
    # Node implementations
    # ------------------------------------------------------------------

    def _analyst_node(self, state: ARCAGraphState) -> ARCAGraphState:
        """Fill ``analyst_output`` with a short description of the situation."""
        ns = state["network_state"]
        hosts = self._hosts(ns)
        attack_path = self._episode_info(ns).get("attack_path", [])
        prompt = f"""You are a cybersecurity analyst examining a network penetration test.

Current network state:
- Step: {ns.get('step', 0)}
- Attacker position: Host {ns.get('attacker_node', 0)}
- Hosts discovered: {self._count_discovered(hosts)}
- Hosts compromised: {self._count_compromised(hosts)}
- Total hosts: {len(hosts)}
- Attack path so far: {attack_path}

Briefly describe the current attack situation in 2-3 sentences."""

        output = self._invoke_llm(prompt) or self._rule_based_analyst(ns)
        state["analyst_output"] = output
        return state

    def _critic_node(self, state: ARCAGraphState) -> ARCAGraphState:
        """Fill ``critic_output`` with an evaluation of the agent's behaviour."""
        ns = state["network_state"]
        analyst = state.get("analyst_output", "")
        # ``or 0`` also covers an explicit None reward, which ``:.2f`` rejects.
        reward = float(self._episode_info(ns).get("total_reward") or 0)
        prompt = f"""You are a cybersecurity critic reviewing an RL agent's decisions.

Situation: {analyst}

Recent reward: {reward:.2f}

Critically evaluate:
1. Is the agent making good use of discovered hosts?
2. Is it being efficient in its attack path?
3. What mistakes is it making?

Be concise (2-3 sentences)."""

        output = self._invoke_llm(prompt) or self._rule_based_critic(ns)
        state["critic_output"] = output
        return state

    def _reflector_node(self, state: ARCAGraphState) -> ARCAGraphState:
        """Fill ``reflection`` from the critic output and recent memory."""
        critic = state.get("critic_output", "")
        history = state.get("episode_history", [])
        prompt = f"""You are a reflection module for an RL agent learning pentesting.

Critic feedback: {critic}
Past reflections: {history[-3:] if history else 'None'}

What should the agent learn from this? What patterns should it avoid or repeat?
Be specific and actionable (1-2 sentences)."""

        output = self._invoke_llm(prompt) or "Agent should prioritize scanning before exploiting."
        state["reflection"] = output
        return state

    def _planner_node(self, state: ARCAGraphState) -> ARCAGraphState:
        """Fill ``plan`` with suggested next actions."""
        ns = state["network_state"]
        reflection = state.get("reflection", "")
        compromised = [
            hid for hid, h in self._hosts(ns).items() if h.get("status") == "COMPROMISED"
        ]
        prompt = f"""You are a penetration testing planner.

Reflection: {reflection}
Current position: Host {ns.get('attacker_node', 0)}
Compromised hosts: {compromised}

Suggest the next 3 high-priority actions for the RL agent (SCAN/EXPLOIT/PIVOT/EXFILTRATE).
Format: action: target_host reason"""

        output = self._invoke_llm(prompt) or self._rule_based_plan(ns)
        state["plan"] = output
        return state

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def step(self, network_state: dict) -> dict:
        """Run one analyst -> critic -> reflector -> planner cycle.

        Args:
            network_state: environment snapshot; the nodes read ``step``,
                ``attacker_node``, ``hosts`` and ``episode_info`` from it.

        Returns:
            The final graph state, including ``analyst_output``,
            ``critic_output``, ``reflection`` and ``plan``. A memory entry
            with the step, reflection and plan is also recorded.
        """
        initial_state: ARCAGraphState = {
            "network_state": network_state,
            "analyst_output": "",
            "critic_output": "",
            "reflection": "",
            "plan": "",
            "episode_history": self._memory[-5:],
        }

        if self._graph:
            result = self._graph.invoke(initial_state)
        else:
            # Sequential fallback
            result = self._analyst_node(initial_state)
            result = self._critic_node(result)
            result = self._reflector_node(result)
            result = self._planner_node(result)

        self._memory.append({
            "step": network_state.get("step"),
            "reflection": result.get("reflection"),
            "plan": result.get("plan"),
        })

        return result

    def reflect(self, state: dict) -> dict:
        """Alias for :meth:`step`."""
        return self.step(state)

    def get_memory(self) -> list[dict]:
        """The (live) list of memory entries recorded by :meth:`step`."""
        return self._memory

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _invoke_llm(self, prompt: str) -> Optional[str]:
        """Return the LLM's answer to ``prompt``, or None if no LLM is usable.

        Errors are deliberately swallowed: the nodes treat None (and an empty
        answer) as "use the rule-based fallback".
        """
        if not self._llm_available or self._llm is None:
            return None
        try:
            return self._llm.invoke(prompt)
        except Exception:
            return None

    def _rule_based_analyst(self, ns: dict) -> str:
        """Deterministic situation summary used when no LLM answer is available."""
        hosts = self._hosts(ns)
        comp = self._count_compromised(hosts)
        disc = self._count_discovered(hosts)
        total = len(hosts)
        return (
            f"Agent has compromised {comp}/{total} hosts and discovered {disc}/{total}. "
            f"Currently operating from host {ns.get('attacker_node', 0)}. "
            f"Progress: {comp/max(total,1)*100:.0f}% of objective achieved."
        )

    def _rule_based_critic(self, ns: dict) -> str:
        """Deterministic critique based on the compromised/discovered ratio."""
        hosts = self._hosts(ns)
        comp = self._count_compromised(hosts)
        disc = self._count_discovered(hosts)
        if disc == 0:
            return "Agent should scan more aggressively to discover hosts before exploiting."
        ratio = comp / max(disc, 1)
        if ratio < 0.3:
            return "Low exploit ratio. Agent discovers hosts but fails to compromise them efficiently."
        return "Agent is making reasonable progress. Consider pivoting to reach deeper subnets."

    def _rule_based_plan(self, ns: dict) -> str:
        """Deterministic plan: scan the first undiscovered host, exploit the
        first discovered-but-uncompromised host, then pivot."""
        hosts = self._hosts(ns)
        undiscovered = [hid for hid, h in hosts.items() if not h.get("discovered")]
        exploitable = [
            hid for hid, h in hosts.items()
            if h.get("discovered") and h.get("status") != "COMPROMISED"
        ]
        plan_parts = []
        if undiscovered:
            plan_parts.append(f"SCAN: target host {undiscovered[0]} — discover new hosts")
        if exploitable:
            plan_parts.append(f"EXPLOIT: target host {exploitable[0]} — compromise vulnerable host")
        plan_parts.append("PIVOT: move to highest-value compromised host for better reach")
        return "\n".join(plan_parts)



# ------------------- viz -------------------


# File: arca/viz/__init__.py
# ================================================
"""arca.viz — Visualization suite (Plotly + NetworkX + Matplotlib fallback)."""

from arca.viz.visualizer import ARCAVisualizer

__all__ = ["ARCAVisualizer"]



# File: arca/viz/visualizer.py
# ================================================
"""
arca.viz.visualizer
~~~~~~~~~~~~~~~~~~~
Rich visualization suite for ARCA using Plotly and NetworkX.

Plots:
  - Network topology graph (with host status coloring)
  - Attack path overlay
  - Training curves (reward, hosts compromised, attack path length,
    exploit success rate)
  - Vulnerability density heatmap (host x OS)
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Optional, Any

try:
    import plotly.graph_objects as go
    import plotly.express as px
    from plotly.subplots import make_subplots
    PLOTLY_AVAILABLE = True
except ImportError:
    PLOTLY_AVAILABLE = False

try:
    import networkx as nx
    NX_AVAILABLE = True
except ImportError:
    NX_AVAILABLE = False

try:
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    MPL_AVAILABLE = True
except ImportError:
    MPL_AVAILABLE = False


class ARCAVisualizer:
    """Static visualization methods for ARCA network states and training metrics.

    Each ``plot_*`` method builds a figure, writes it under ``output_dir`` when
    ``save`` is true, and returns it; the Plotly-based methods also call
    ``fig.show()`` when ``show`` is true. ``None`` is returned when the needed
    plotting backend is not installed or there is nothing to draw.
    """

    def __init__(self, output_dir: str = "arca_outputs/figures"):
        # Created up front so every plot_* call can write without checking.
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _status_name(host) -> str:
        """Return ``host.status`` as a plain string.

        Enum-like statuses (anything with a ``.name``) give their name; any
        other value is converted with ``str()``.
        """
        status = host.status
        return status.name if hasattr(status, "name") else str(status)

    @classmethod
    def _mpl_node_colors(cls, nodes, hosts: dict) -> list:
        """Matplotlib colour for each node in ``nodes``, in the same order.

        Nodes missing from ``hosts`` or with an unrecognised status are drawn
        gray, so the colour list always lines up one-to-one with ``nodes``.
        """
        color_map = {"UNKNOWN": "gray", "DISCOVERED": "orange", "COMPROMISED": "red"}
        return [
            color_map.get(cls._status_name(hosts[n]), "gray") if n in hosts else "gray"
            for n in nodes
        ]

    # ------------------------------------------------------------------
    # Network Graph
    # ------------------------------------------------------------------

    def plot_network(
        self,
        graph,
        hosts: dict,
        title: str = "ARCA Network Topology",
        save: bool = True,
        show: bool = False,
    ):
        """Interactive topology view with one marker per host, coloured by status.

        Args:
            graph: networkx graph; its node ids are matched against ``hosts``.
            hosts: host id -> host object exposing ``status``, ``os``, ``ip``,
                ``subnet``, ``vulnerabilities``, ``services``, ``is_critical``.
            title: figure title.
            save: write ``network_topology.html`` to ``output_dir``.
            show: open the figure.

        Falls back to a Matplotlib PNG (``_mpl_network``) when Plotly or
        NetworkX is missing.
        """
        if not PLOTLY_AVAILABLE or not NX_AVAILABLE:
            return self._mpl_network(graph, hosts, title, save)

        # Fixed seed so the layout is stable between runs.
        pos = nx.spring_layout(graph, seed=42)

        # Edge traces: one polyline with None separators between segments.
        edge_x, edge_y = [], []
        for u, v in graph.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x += [x0, x1, None]
            edge_y += [y0, y1, None]

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.8, color="#334155"),
            hoverinfo="none",
            mode="lines",
        )

        # Node traces by status
        color_map = {
            "UNKNOWN": "#1e293b",
            "DISCOVERED": "#f59e0b",
            "COMPROMISED": "#ef4444",
        }

        node_x, node_y, node_colors, node_text, node_hover = [], [], [], [], []
        for node, host in hosts.items():
            if node not in pos:
                continue
            x, y = pos[node]
            status = self._status_name(host)
            node_x.append(x)
            node_y.append(y)
            node_colors.append(color_map.get(status, "#64748b"))
            node_text.append(str(node))
            node_hover.append(
                f"Host {node}<br>"
                f"OS: {host.os}<br>"
                f"IP: {host.ip}<br>"
                f"Status: {status}<br>"
                f"Subnet: {host.subnet}<br>"
                f"Vulns: {len(host.vulnerabilities)}<br>"
                f"Services: {', '.join(host.services)}<br>"
                f"Critical: {'⭐' if host.is_critical else 'No'}"
            )

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode="markers+text",
            text=node_text,
            textposition="top center",
            hovertext=node_hover,
            hoverinfo="text",
            marker=dict(
                size=20,
                color=node_colors,
                line=dict(width=2, color="#94a3b8"),
                symbol="circle",
            ),
        )

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                title=dict(text=title, font=dict(color="#f1f5f9", size=16)),
                paper_bgcolor="#0f172a",
                plot_bgcolor="#0f172a",
                showlegend=False,
                hovermode="closest",
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                font=dict(color="#f1f5f9"),
                margin=dict(l=20, r=20, t=50, b=20),
                annotations=[
                    dict(text="⬛ Unknown  🟡 Discovered  🔴 Compromised",
                         x=0.5, y=-0.05, xref="paper", yref="paper",
                         showarrow=False, font=dict(color="#94a3b8", size=11))
                ],
            ),
        )

        if save:
            path = self.output_dir / "network_topology.html"
            fig.write_html(str(path))
            print(f"[ARCA] Network graph saved → {path}")
        if show:
            fig.show()
        return fig

    # ------------------------------------------------------------------
    # Training Curves
    # ------------------------------------------------------------------

    def plot_training_curves(
        self,
        log_data: dict,
        save: bool = True,
        show: bool = False,
    ):
        """2x2 grid of training metrics.

        Args:
            log_data: optional keys ``"episodes"`` (x values), ``"rewards"``,
                ``"compromised"``, ``"path_lengths"``, ``"success_rates"``.
                Missing series are simply not drawn. When ``"episodes"`` is
                absent, the x axis is ``0..N-1`` with N the longest series.
            save: write ``training_curves.html`` to ``output_dir``.
            show: open the figure.
        """
        if not PLOTLY_AVAILABLE:
            return None

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=[
                "Episode Reward", "Hosts Compromised / Episode",
                "Attack Path Length", "Exploit Success Rate"
            ],
            vertical_spacing=0.15,
        )

        rewards = log_data.get("rewards", [])
        compromised = log_data.get("compromised", [])
        path_lengths = log_data.get("path_lengths", [])
        success_rates = log_data.get("success_rates", [])
        # Size the default x axis from the longest series, not only rewards;
        # otherwise a log without rewards would plot every series against x=[].
        longest = max(
            len(series) if series is not None else 0
            for series in (rewards, compromised, path_lengths, success_rates)
        )
        episodes = log_data.get("episodes", list(range(longest)))

        # Smooth helper: "valid"-mode moving average of width w.
        def smooth(data, w=5):
            import numpy as np
            if len(data) < w:
                return data
            kernel = np.ones(w) / w
            return np.convolve(data, kernel, mode="valid").tolist()

        color = "#22d3ee"
        smooth_color = "#f472b6"

        if rewards:
            fig.add_trace(go.Scatter(x=episodes[:len(rewards)], y=rewards,
                                     mode="lines", name="Reward", line=dict(color=color, width=1),
                                     opacity=0.5), row=1, col=1)
            s = smooth(rewards)
            # Each smoothed point belongs to the last episode of its window.
            fig.add_trace(go.Scatter(x=episodes[len(rewards)-len(s):len(rewards)], y=s,
                                     mode="lines", name="Smoothed", line=dict(color=smooth_color, width=2)),
                          row=1, col=1)

        if compromised:
            fig.add_trace(go.Scatter(x=episodes[:len(compromised)], y=compromised,
                                     mode="lines+markers", name="Compromised",
                                     line=dict(color="#f59e0b", width=1.5),
                                     marker=dict(size=3)), row=1, col=2)

        if path_lengths:
            fig.add_trace(go.Scatter(x=episodes[:len(path_lengths)], y=path_lengths,
                                     mode="lines", name="Path Length",
                                     line=dict(color="#a78bfa", width=1.5)), row=2, col=1)

        if success_rates:
            fig.add_trace(go.Scatter(x=episodes[:len(success_rates)], y=success_rates,
                                     mode="lines", name="Success Rate",
                                     fill="tozeroy",
                                     line=dict(color="#34d399", width=1.5)), row=2, col=2)

        fig.update_layout(
            title=dict(text="ARCA Training Metrics", font=dict(color="#f1f5f9", size=16)),
            paper_bgcolor="#0f172a",
            plot_bgcolor="#1e293b",
            font=dict(color="#f1f5f9"),
            showlegend=False,
            height=600,
        )
        fig.update_xaxes(gridcolor="#334155", zerolinecolor="#334155")
        fig.update_yaxes(gridcolor="#334155", zerolinecolor="#334155")

        if save:
            path = self.output_dir / "training_curves.html"
            fig.write_html(str(path))
            print(f"[ARCA] Training curves saved → {path}")
        if show:
            fig.show()
        return fig

    # ------------------------------------------------------------------
    # Attack Path
    # ------------------------------------------------------------------

    def plot_attack_path(
        self,
        attack_path: list[str],
        hosts: dict,
        save: bool = True,
        show: bool = False,
    ):
        """Timeline of attack-path entries, one marker per step.

        Args:
            attack_path: step descriptions. The label is the text before the
                first ``"("``; a ``"CVE:<id>)"`` suffix, if present, is shown
                on hover.
            hosts: currently unused; kept for interface compatibility.
            save: write ``attack_path.html`` to ``output_dir``.
            show: open the figure.

        Returns None when Plotly is missing or ``attack_path`` is empty.
        """
        if not PLOTLY_AVAILABLE or not attack_path:
            return None

        steps = list(range(len(attack_path)))
        labels = [p.split("(")[0] for p in attack_path]
        cve_ids = [p.split("CVE:")[1].rstrip(")") if "CVE:" in p else "?" for p in attack_path]

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=steps, y=[1] * len(steps),
            mode="markers+text+lines",
            text=labels,
            textposition="top center",
            marker=dict(size=20, color="#ef4444",
                        symbol="arrow-right", line=dict(width=2, color="#fca5a5")),
            line=dict(color="#ef4444", width=2, dash="dot"),
            hovertext=[f"Step {i}: {p}<br>CVE: {c}" for i, (p, c) in enumerate(zip(attack_path, cve_ids))],
            hoverinfo="text",
        ))

        fig.update_layout(
            title=dict(text="ARCA Attack Path", font=dict(color="#f1f5f9", size=16)),
            paper_bgcolor="#0f172a",
            plot_bgcolor="#1e293b",
            font=dict(color="#f1f5f9"),
            xaxis=dict(title="Attack Step", gridcolor="#334155"),
            yaxis=dict(showticklabels=False, range=[0.5, 1.5]),
            height=300,
        )

        if save:
            path = self.output_dir / "attack_path.html"
            fig.write_html(str(path))
            print(f"[ARCA] Attack path saved → {path}")
        if show:
            fig.show()
        return fig

    # ------------------------------------------------------------------
    # Vulnerability heatmap
    # ------------------------------------------------------------------

    def plot_vuln_heatmap(
        self,
        hosts: dict,
        save: bool = True,
        show: bool = False,
    ):
        """Heatmap of vulnerability counts: one row per OS, one column per host.

        Each host contributes ``len(host.vulnerabilities)`` to the cell of its
        own OS; all other cells in that column stay 0.
        """
        if not PLOTLY_AVAILABLE:
            return None

        import numpy as np
        host_ids = sorted(hosts.keys())
        os_list = sorted({h.os for h in hosts.values()})
        # Dict lookup instead of list.index() inside the loop.
        os_row = {os_name: row for row, os_name in enumerate(os_list)}
        matrix = np.zeros((len(os_list), len(host_ids)))

        for j, hid in enumerate(host_ids):
            host = hosts[hid]
            matrix[os_row[host.os]][j] = len(host.vulnerabilities)

        fig = go.Figure(data=go.Heatmap(
            z=matrix,
            x=[f"H{i}" for i in host_ids],
            y=os_list,
            colorscale="Reds",
            showscale=True,
            text=matrix.astype(int),
            texttemplate="%{text}",
        ))
        fig.update_layout(
            title=dict(text="Vulnerability Density by Host & OS", font=dict(color="#f1f5f9")),
            paper_bgcolor="#0f172a",
            plot_bgcolor="#1e293b",
            font=dict(color="#f1f5f9"),
            height=350,
        )
        if save:
            path = self.output_dir / "vuln_heatmap.html"
            fig.write_html(str(path))
            print(f"[ARCA] Vulnerability heatmap saved → {path}")
        if show:
            fig.show()
        return fig

    # ------------------------------------------------------------------
    # MPL fallback
    # ------------------------------------------------------------------

    def _mpl_network(self, graph, hosts, title, save):
        """Static Matplotlib version of :meth:`plot_network` (PNG output).

        Returns the (closed) Matplotlib figure, or None when Matplotlib or
        NetworkX is unavailable.
        """
        if not MPL_AVAILABLE or not NX_AVAILABLE:
            print("[ARCA] No visualization backend available.")
            return None

        # One colour per graph node, in graph order, so the list length always
        # matches what draw_networkx draws (nodes absent from ``hosts`` are gray).
        colors = self._mpl_node_colors(list(graph.nodes()), hosts)

        fig, ax = plt.subplots(figsize=(10, 7), facecolor="#0f172a")
        ax.set_facecolor("#1e293b")
        pos = nx.spring_layout(graph, seed=42)
        nx.draw_networkx(graph, pos=pos, ax=ax, node_color=colors,
                         edge_color="#334155", font_color="white", node_size=500)
        ax.set_title(title, color="white")
        if save:
            path = self.output_dir / "network_topology.png"
            fig.savefig(str(path), dpi=150, bbox_inches="tight", facecolor="#0f172a")
            print(f"[ARCA] Network graph saved → {path}")
        # Close this figure explicitly rather than whatever pyplot considers current.
        plt.close(fig)
        return fig



# ------------------- api -------------------


# File: arca/api/__init__.py
# ================================================
"""arca.api — FastAPI REST server."""

from arca.api.server import app

__all__ = ["app"]



# File: arca/api/server.py
# ================================================
"""
arca.api.server
===============
FastAPI REST interface for ARCA.

Start with:  arca serve
             uvicorn arca.api.server:app --reload --port 8000

Endpoints:
  GET  /             — health + version
  GET  /status       — current agent/env status
  POST /train        — start a training run
  POST /audit        — run an audit episode
  POST /reflect      — run LangGraph reflection
  GET  /presets      — list available network presets
"""

from __future__ import annotations

import time
from typing import Any, Optional

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from arca.__version__ import __version__

# ── App setup ────────────────────────────────────────────────────────────────

app = FastAPI(
    title="ARCA — Autonomous Reinforcement Cyber Agent",
    description=(
        "Fully local RL-powered pentesting simulation with LangGraph orchestration. "
        "All computation runs on your machine — no data leaves it."
    ),
    version=__version__,
    docs_url="/docs",
    redoc_url="/redoc",
)

# NOTE(security): allow_origins=["*"] lets scripts on any web page the user
# visits call this server (e.g. POST /train, which starts a long training
# run) and read the responses. That is convenient for local dashboards but
# broad for a service like this; consider restricting allow_origins to the
# dashboard origins you actually use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── In-memory state (single-agent server) ────────────────────────────────────

_state: dict[str, Any] = {
    "agent": None,
    "env": None,
    "cfg": None,
    "training_active": False,
    "last_trained_at": None,
    "total_timesteps_trained": 0,
}


# ── Request / Response schemas ───────────────────────────────────────────────

# Body of POST /train. Pydantic enforces the bounds before the route runs, so
# an out-of-range request gets a 422 instead of failing midway through training.
class TrainRequest(BaseModel):
    preset: str = Field("small_office", description="Network preset name")
    timesteps: int = Field(50_000, ge=1_000, le=5_000_000, description="Training timesteps")
    algorithm: str = Field("PPO", description="RL algorithm: PPO | A2C | DQN")
    # A zero or negative learning rate is never meaningful; reject it up front.
    learning_rate: float = Field(3e-4, gt=0, description="Learning rate")


# Body of POST /audit. ``timesteps`` is only used when a fresh agent has to be
# quick-trained. It shares /train's upper bound so one audit request cannot
# ask for an unbounded training run.
class AuditRequest(BaseModel):
    preset: str = Field("small_office", description="Network preset")
    timesteps: int = Field(20_000, ge=500, le=5_000_000, description="Quick-train timesteps if no agent loaded")
    use_existing: bool = Field(True, description="Use already-trained agent if available")
    langgraph: bool = Field(False, description="Enable LangGraph LLM reflection")


# Body of POST /reflect. The /reflect route uses ``state`` when it is truthy;
# when it is omitted, None or empty, the route uses the loaded environment's
# get_state_dict() instead.
class ReflectRequest(BaseModel):
    state: Optional[dict] = Field(None, description="Network state dict (auto-generates if None)")


# ── Routes ───────────────────────────────────────────────────────────────────

@app.get("/", tags=["Health"])
def root():
    # Liveness probe: identifies the service and reports whether an agent is loaded.
    agent_ready = _state["agent"] is not None
    return dict(
        status="ok",
        service="ARCA API",
        version=__version__,
        docs="/docs",
        agent_ready=agent_ready,
    )


@app.get("/status", tags=["Health"])
def status():
    # Snapshot of the single-agent server state; "preset" is None until a
    # config has been stored by /train or /audit.
    cfg = _state["cfg"]
    return dict(
        agent_loaded=_state["agent"] is not None,
        training_active=_state["training_active"],
        last_trained_at=_state["last_trained_at"],
        total_timesteps_trained=_state["total_timesteps_trained"],
        preset=cfg.env.preset if cfg else None,
    )


@app.get("/presets", tags=["Info"])
def list_presets():
    # Imported inside the route, like the other routes' arca imports.
    from arca.sim.environment import PRESETS

    summary = {}
    for preset_name, preset_cfg in PRESETS.items():
        summary[preset_name] = {
            "num_hosts": preset_cfg.num_hosts,
            "num_subnets": preset_cfg.num_subnets,
            "vulnerability_density": preset_cfg.vulnerability_density,
            "max_steps": preset_cfg.max_steps,
        }
    return {"presets": summary}


@app.post("/train", tags=["Training"])
def train(req: TrainRequest):
    """Train a new RL agent. Blocks until training is complete."""
    if _state["training_active"]:
        raise HTTPException(status_code=409, detail="Training already in progress.")

    # Claim the flag before any setup work. FastAPI runs plain ``def``
    # endpoints in a worker threadpool, so another /train can arrive while
    # this one is still building its env/agent; it must already get a 409.
    _state["training_active"] = True
    try:
        from arca.core.config import ARCAConfig
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv

        cfg = ARCAConfig.default()
        cfg.env.preset = req.preset
        cfg.rl.algorithm = req.algorithm
        cfg.rl.learning_rate = req.learning_rate
        cfg.verbose = 0
        cfg.ensure_dirs()

        env = NetworkEnv.from_preset(req.preset, cfg=cfg)
        agent = ARCAAgent(env=env, cfg=cfg)

        # perf_counter is monotonic, so the duration is immune to clock changes.
        start = time.perf_counter()
        agent.train(timesteps=req.timesteps, progress_bar=False)
        elapsed = round(time.perf_counter() - start, 2)

        # Publish the new agent only after training succeeded.
        _state["agent"] = agent
        _state["env"] = env
        _state["cfg"] = cfg
        _state["last_trained_at"] = time.strftime("%Y-%m-%dT%H:%M:%S")
        _state["total_timesteps_trained"] += req.timesteps

        return {
            "status": "success",
            "preset": req.preset,
            "algorithm": req.algorithm,
            "timesteps": req.timesteps,
            "elapsed_seconds": elapsed,
            "model_ready": True,
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release the flag on success and on failure alike.
        _state["training_active"] = False


@app.post("/audit", tags=["Audit"])
def audit(req: AuditRequest):
    """Run a security audit episode and return a structured report."""
    try:
        from arca.core.config import ARCAConfig
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv

        agent = _state["agent"]
        env = _state["env"]
        cfg = _state["cfg"]

        if not req.use_existing or agent is None:
            # Quick-train an agent (reduced n_steps / batch_size) for this
            # audit; it becomes the server's current agent.
            cfg = ARCAConfig.default()
            cfg.env.preset = req.preset
            cfg.rl.n_steps = 64
            cfg.rl.batch_size = 32
            cfg.verbose = 0
            cfg.ensure_dirs()
            env = NetworkEnv.from_preset(req.preset, cfg=cfg)
            agent = ARCAAgent(env=env, cfg=cfg)
            agent.train(timesteps=req.timesteps, progress_bar=False)
            _state["agent"] = agent
            _state["env"] = env
            _state["cfg"] = cfg
            # Keep /status consistent with /train: this agent was trained too.
            _state["last_trained_at"] = time.strftime("%Y-%m-%dT%H:%M:%S")
            _state["total_timesteps_trained"] += req.timesteps

        info = agent.run_episode()

        report = {
            "preset": cfg.env.preset,
            "total_reward": round(info.total_reward, 2),
            "steps": info.steps,
            "hosts_compromised": info.hosts_compromised,
            "hosts_discovered": info.hosts_discovered,
            # NOTE(review): taken from cfg.env.num_hosts; confirm that
            # NetworkEnv.from_preset updates it to the preset's host count.
            "total_hosts": cfg.env.num_hosts,
            "goal_reached": info.goal_reached,
            "attack_path": info.attack_path,
            "summary": info.summary(),
        }

        if req.langgraph:
            agent.enable_langgraph()
            state = env.get_state_dict()
            reflection = agent.reflect(state)
            report["llm_analysis"] = reflection.get("reflection", "N/A")
            report["llm_plan"] = reflection.get("plan", "N/A")

        return report

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/reflect", tags=["LangGraph"])
def reflect(req: ReflectRequest):
    """Run a LangGraph reflection cycle on a network state."""
    try:
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv
        from arca.core.config import ARCAConfig

        agent, env = _state["agent"], _state["env"]

        if agent is None:
            # No agent loaded yet: build an untrained one on the default
            # preset so the reflection pipeline has an env to read from.
            cfg = ARCAConfig.default()
            env = NetworkEnv.from_preset("small_office", cfg=cfg)
            env.reset()
            agent = ARCAAgent(env=env, cfg=cfg)

        # A falsy request state (None or {}) falls back to the env snapshot.
        network_state = req.state if req.state else env.get_state_dict()
        agent.enable_langgraph()
        return {"status": "ok", "reflection": agent.reflect(network_state)}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))



# ------------------- utils -------------------


# File: arca/utils/__init__.py
# ================================================
"""arca.utils — Shared utilities."""

from __future__ import annotations

import json
import time
from pathlib import Path
from typing import Any


def save_json(data: dict, path: str | Path) -> None:
    """Save a dict as pretty-printed JSON."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump(data, f, indent=2, default=str)


def load_json(path: str | Path) -> dict:
    """Load a JSON file."""
    with open(path) as f:
        return json.load(f)


def smooth(data: list[float], window: int = 5) -> list[float]:
    """Simple moving-average smoothing.

    Returns the mean of every full ``window``-length slice of ``data`` (a
    "valid"-mode moving average with ``len(data) - window + 1`` items). If
    ``data`` is shorter than ``window`` it is returned unchanged.

    Raises:
        ValueError: if ``window`` is less than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    if len(data) < window:
        return data
    # Pure Python on purpose: no numpy dependency for a simple mean.
    return [sum(data[i:i + window]) / window for i in range(len(data) - window + 1)]


def timestamp() -> str:
    """Current local time formatted as ``YYYYmmdd_HHMMSS`` (e.g. ``20260412_153012``).

    The fixed-width, most-significant-first layout makes these strings sort
    in time order.
    """
    layout = "%Y%m%d_%H%M%S"
    return time.strftime(layout, time.localtime())


class Timer:
    """Simple context-manager timer.

    After a ``with`` block, ``elapsed`` holds its duration in seconds. Before
    the timer has been used ``elapsed`` is ``0.0``, so ``str()`` is always safe.

    Example::

        with Timer() as t:
            work()
        print(f"took {t}")
    """

    def __init__(self) -> None:
        self._start: float | None = None
        self.elapsed: float = 0.0

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which jumps if the system clock is adjusted.
        self._start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.elapsed = time.perf_counter() - self._start
        # Implicitly returns None, so exceptions from the block propagate.

    def __str__(self):
        return f"{self.elapsed:.2f}s"


__all__ = ["save_json", "load_json", "smooth", "timestamp", "Timer"]



# ------------------- cpp_ext -------------------


# File: arca/cpp_ext/__init__.py
# ================================================
"""C++ extension module. Falls back gracefully if not compiled."""

try:
    from arca._cpp_sim import compute_reachability, floyd_warshall, batch_exploit  # type: ignore
    CPP_AVAILABLE = True
except ImportError:
    # The compiled extension is optional. These pure-Python versions are bound
    # to the same names so callers can use them without checking CPP_AVAILABLE.
    CPP_AVAILABLE = False

    def compute_reachability(adj, n_nodes):
        """Pure-Python fallback.

        Args:
            adj: adjacency lists; ``adj[u]`` holds the node indices reachable
                from ``u`` in one hop. Rows at index >= ``len(adj)`` are
                treated as empty.
            n_nodes: number of nodes.

        Returns:
            ``n_nodes x n_nodes`` list of bools; ``reach[s][t]`` is True iff
            ``t`` is reachable from ``s`` (every node reaches itself).
        """
        from collections import deque
        reach = [[False] * n_nodes for _ in range(n_nodes)]
        for src in range(n_nodes):
            visited = [False] * n_nodes
            q = deque([src])
            visited[src] = True
            reach[src][src] = True
            while q:
                u = q.popleft()
                for v in (adj[u] if u < len(adj) else []):
                    if not visited[v]:
                        visited[v] = True
                        reach[src][v] = True
                        q.append(v)
        return reach

    def floyd_warshall(weights, n):
        """Pure-Python fallback.

        Args:
            weights: ``n x n`` matrix of edge weights (``float("inf")`` for
                "no edge"). Rows may be lists or tuples; the input is never
                modified.
            n: number of nodes.

        Returns:
            A new ``n x n`` list-of-lists of shortest-path distances.
        """
        inf = float("inf")
        # list(row) always copies (row[:] would alias numpy rows and fail on
        # tuple rows at assignment time).
        dist = [list(row) for row in weights]
        for k in range(n):
            row_k = dist[k]
            for i in range(n):
                row_i = dist[i]
                # An infinite d(i, k) can never shorten a path; skip the row.
                if row_i[k] == inf:
                    continue
                for j in range(n):
                    candidate = row_i[k] + row_k[j]
                    if candidate < row_i[j]:
                        row_i[j] = candidate
        return dist

    def batch_exploit(hosts, actions, seed=42):
        """Pure-Python fallback.

        Resolves each ``(target, exploit_id)`` action against ``hosts`` using a
        ``random.Random(seed)`` stream, so results are reproducible.
        ``exploit_id`` is ignored here. A host's success probability is its
        ``"exploit_prob"`` entry (default 0.5). Success gives reward 20.0 and
        failure gives -0.5.

        A ``target`` outside ``range(len(hosts))`` yields
        ``{"success": False, "reward": -1.0, "compromised_host": -1}`` and
        consumes no random number. This includes negative indices, which
        Python would otherwise silently wrap to the end of the list.
        """
        import random
        rng = random.Random(seed)
        results = []
        for target, _exploit_id in actions:
            if not 0 <= target < len(hosts):
                results.append({"success": False, "reward": -1.0, "compromised_host": -1})
                continue
            prob = hosts[target].get("exploit_prob", 0.5)
            success = rng.random() < prob
            results.append({
                "success": success,
                "reward": 20.0 if success else -0.5,
                "compromised_host": target if success else -1,
            })
        return results


__all__ = ["CPP_AVAILABLE", "compute_reachability", "floyd_warshall", "batch_exploit"]



# File: arca/cpp_ext/sim_engine.cpp
# ================================================
/*
 * arca/cpp_ext/sim_engine.cpp
 * ===========================
 * Performance-critical simulation primitives exposed to Python via pybind11.
 *
 * Exposes (imported by arca.cpp_ext as module-level functions):
 *   compute_reachability(adj, n_nodes) -> n_nodes x n_nodes reachability matrix
 *   batch_exploit(hosts, actions, ...) -> one result record per action
 *   floyd_warshall(weights, n)         -> all-pairs shortest-path distances
 *
 * Build: pip install pybind11 && pip install -e ".[cpp]"
 */

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

#include <cstddef>
#include <cstdint>
#include <limits>
#include <queue>
#include <random>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace py = pybind11;

// ------------------------------------------------------------------ //
//  BFS-based reachability computation (faster than networkx for      //
//  dense adjacency on small graphs)                                  //
// ------------------------------------------------------------------ //
// adj[u] lists the direct successors of node u. Nodes with index >= adj.size()
// are treated as having no outgoing edges. Returns an n_nodes x n_nodes matrix
// where reach[s][t] is true iff t is reachable from s; every node reaches
// itself.
// Throws std::invalid_argument if n_nodes < 0, and std::out_of_range if a
// neighbour id lies outside [0, n_nodes). pybind11 translates these to
// Python ValueError and IndexError.
std::vector<std::vector<bool>> compute_reachability(
    const std::vector<std::vector<int>>& adj,
    int n_nodes
) {
    if (n_nodes < 0) {
        throw std::invalid_argument("compute_reachability: n_nodes must be non-negative");
    }
    std::vector<std::vector<bool>> reach(n_nodes, std::vector<bool>(n_nodes, false));
    const int n_rows = static_cast<int>(adj.size());

    for (int src = 0; src < n_nodes; ++src) {
        std::vector<bool> visited(n_nodes, false);
        std::queue<int> frontier;
        frontier.push(src);
        visited[src] = true;
        reach[src][src] = true;
        while (!frontier.empty()) {
            const int u = frontier.front();
            frontier.pop();
            if (u >= n_rows) continue;  // no adjacency row: no outgoing edges
            for (const int v : adj[u]) {
                // Validate before indexing. An out-of-range id would otherwise
                // read/write past the end of visited/reach (undefined behaviour).
                if (v < 0 || v >= n_nodes) {
                    throw std::out_of_range("compute_reachability: neighbour index out of range");
                }
                if (!visited[v]) {
                    visited[v] = true;
                    reach[src][v] = true;
                    frontier.push(v);
                }
            }
        }
    }
    return reach;
}

// ------------------------------------------------------------------ //
//  Floyd-Warshall for all-pairs shortest path                        //
// ------------------------------------------------------------------ //
// weights[i][j] is the direct edge cost from i to j (use +infinity for
// "no edge"). Returns a copy of `weights` whose leading n x n sub-matrix
// holds shortest-path costs. If n <= 0 the copy is returned unchanged.
// Throws std::out_of_range (Python IndexError via pybind11) if `weights`
// has fewer than n rows, or any of its first n rows has fewer than n columns.
std::vector<std::vector<double>> floyd_warshall(
    const std::vector<std::vector<double>>& weights,
    int n
) {
    const double INF = std::numeric_limits<double>::infinity();
    std::vector<std::vector<double>> dist(weights);
    if (n <= 0) return dist;

    const std::size_t need = static_cast<std::size_t>(n);
    if (dist.size() < need) {
        throw std::out_of_range("floyd_warshall: weights must be at least n x n");
    }
    for (std::size_t i = 0; i < need; ++i) {
        if (dist[i].size() < need) {
            throw std::out_of_range("floyd_warshall: weights must be at least n x n");
        }
    }

    for (int k = 0; k < n; ++k) {
        const std::vector<double>& row_k = dist[k];
        for (int i = 0; i < n; ++i) {
            std::vector<double>& row_i = dist[i];
            // No path i -> k: routing through k cannot shorten anything.
            if (row_i[k] == INF) continue;
            for (int j = 0; j < n; ++j) {
                const double candidate = row_i[k] + row_k[j];
                if (candidate < row_i[j]) row_i[j] = candidate;
            }
        }
    }
    return dist;
}

// ------------------------------------------------------------------ //
//  Batch exploit simulation                                          //
//  Returns one ExploitResult (success, reward, compromised_host) per  //
//  action                                                            //
// ------------------------------------------------------------------ //
struct ExploitResult {
    bool success;          // true if the simulated exploit succeeded
    double reward;         // 20.0 on success, -0.5 on failure, -1.0 for an out-of-range target
    int compromised_host;  // target host index on success, otherwise -1
};

// Simulates each (target_host, exploit_id) action independently. An action
// succeeds with probability hosts[target]["exploit_prob"], or 0.5 if the key
// is absent. Draws come from a std::mt19937_64 seeded with `seed`, so a given
// seed always produces the same results. exploit_id is currently unused.
// Targets outside [0, hosts.size()) yield {false, -1.0, -1} without consuming
// a random draw.
std::vector<ExploitResult> batch_exploit(
    const std::vector<std::unordered_map<std::string, double>>& hosts,
    const std::vector<std::pair<int, int>>& actions,  // (target_host, exploit_id)
    uint64_t seed
) {
    static const std::string kProbKey = "exploit_prob";
    constexpr double kDefaultProb = 0.5;

    std::mt19937_64 rng(seed);
    std::uniform_real_distribution<double> uniform(0.0, 1.0);
    const int n_hosts = static_cast<int>(hosts.size());

    std::vector<ExploitResult> results;
    results.reserve(actions.size());

    for (const auto& action : actions) {
        const int target = action.first;
        // Negative targets must be rejected too; indexing with them is
        // undefined behaviour.
        if (target < 0 || target >= n_hosts) {
            results.push_back({false, -1.0, -1});
            continue;
        }
        const auto& host = hosts[static_cast<std::size_t>(target)];
        const auto it = host.find(kProbKey);
        const double prob = (it != host.end()) ? it->second : kDefaultProb;

        const bool success = uniform(rng) < prob;
        results.push_back({success, success ? 20.0 : -0.5, success ? target : -1});
    }
    return results;
}

// ------------------------------------------------------------------ //
//  pybind11 module                                                   //
// ------------------------------------------------------------------ //
PYBIND11_MODULE(_cpp_sim, m) {
    m.doc() = "ARCA C++ accelerated simulation engine";

    // Read-only Python view of one batch_exploit outcome.
    py::class_<ExploitResult>(m, "ExploitResult")
        .def_readonly("success", &ExploitResult::success)
        .def_readonly("reward", &ExploitResult::reward)
        .def_readonly("compromised_host", &ExploitResult::compromised_host);

    m.def("compute_reachability", &compute_reachability,
          py::arg("adj"), py::arg("n_nodes"),
          "BFS-based all-pairs reachability. Returns bool[n][n] matrix.");

    m.def("floyd_warshall", &floyd_warshall,
          py::arg("weights"), py::arg("n"),
          "All-pairs shortest path via Floyd-Warshall.");

    m.def("batch_exploit", &batch_exploit,
          py::arg("hosts"), py::arg("actions"), py::arg("seed") = 42ULL,
          "Batch exploit simulation. Returns list of ExploitResult.");

    // Version info: keep in sync with `version` in pyproject.toml.
    m.attr("__version__") = "0.1.1";
    m.attr("__cpp_available__") = true;
}



# ==================== EXAMPLES & TESTS ====================


# File: examples/quickstart.py
# ================================================
#!/usr/bin/env python3
"""
examples/quickstart.py
======================
ARCA end-to-end quickstart.

Run from project root:
    python examples/quickstart.py

Demonstrates:
  1. Configuration (small_office preset, PPO, 10k steps)
  2. Environment creation from preset
  3. PPO training (10k steps)
  4. Episode evaluation (3 episodes)
  5. LangGraph reflection (rule-based fallback if no Ollama)
  6. Visualization suite (saves HTML files)
and finally saves the trained model.
"""

from __future__ import annotations

import random
import sys
from pathlib import Path

# Allow running from project root without installing
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))


def banner(msg: str):
    """Print ``msg`` as a section heading framed by 55-character '=' rules.

    A blank line is written first so consecutive sections stay visually separate.
    """
    rule = "=" * 55
    print("\n".join(["", rule, f"  {msg}", rule]))


def main():
    """Run the ARCA quickstart end to end.

    Stages: configure, build the preset environment, train the agent, evaluate
    a few episodes, and run LangGraph reflection (rule-based fallback when no
    local Ollama is running). It then renders the visualisation suite to HTML
    and saves the model. Each plot is attempted independently, so one failing
    visualisation is reported instead of aborting the demo.
    """
    banner("ARCA — Autonomous Reinforcement Cyber Agent")
    print("Quickstart example — all local, no cloud required.\n")

    # ── 1. Configuration ──────────────────────────────────────────────
    banner("1 / 6  Configuration")
    from arca.core.config import ARCAConfig
    cfg = ARCAConfig.default()
    cfg.env.preset = "small_office"
    cfg.rl.algorithm = "PPO"
    cfg.rl.total_timesteps = 10_000
    cfg.rl.n_steps = 128
    cfg.rl.batch_size = 32
    cfg.verbose = 1
    cfg.ensure_dirs()
    print(f"  Preset    : {cfg.env.preset}")
    print(f"  Hosts     : {cfg.env.num_hosts}")
    print(f"  Subnets   : {cfg.env.num_subnets}")
    print(f"  Algorithm : {cfg.rl.algorithm}")
    print(f"  Steps     : {cfg.rl.total_timesteps:,}")

    # ── 2. Environment ────────────────────────────────────────────────
    banner("2 / 6  Create Network Environment")
    from arca.sim.environment import NetworkEnv
    # Build from the configured preset so the printout and the env cannot diverge.
    env = NetworkEnv.from_preset(cfg.env.preset, cfg=cfg)
    obs, info = env.reset()
    print(f"  Observation shape : {obs.shape}")
    print(f"  Action space      : {env.action_space}")
    print(f"  Attacker starts   : Host {info['attacker_node']}")
    print()
    print(env.render())

    # ── 3. Training ───────────────────────────────────────────────────
    # Read the budget from cfg so the banner, the config printout and the
    # actual training run always agree.
    timesteps = cfg.rl.total_timesteps
    banner(f"3 / 6  Train {cfg.rl.algorithm} Agent ({timesteps:,} steps)")
    from arca.core.agent import ARCAAgent
    agent = ARCAAgent(env=env, cfg=cfg)
    agent.train(timesteps=timesteps, progress_bar=True)
    print("  Training complete!")

    # ── 4. Evaluation ────────────────────────────────────────────────
    n_eval_episodes = 3
    banner(f"4 / 6  Evaluate ({n_eval_episodes} episodes)")
    for episode in range(1, n_eval_episodes + 1):
        info_ep = agent.run_episode(render=False)
        print(f"  Episode {episode}: {info_ep.summary()}")
        attack_path = info_ep.attack_path
        if attack_path:
            # str() each hop so the join works whatever type the path entries are.
            hops = " → ".join(str(hop) for hop in attack_path[:4])
            more = "..." if len(attack_path) > 4 else ""
            print(f"    Attack path: {hops}{more}")

    # ── 5. LangGraph Reflection ───────────────────────────────────────
    banner("5 / 6  LangGraph Reflection")
    print("  (Uses local Ollama if running; falls back to rule-based logic)\n")
    agent.enable_langgraph()
    env.reset()
    state = env.get_state_dict()
    result = agent.reflect(state)

    print(f"  Analyst:    {str(result.get('analyst_output', 'N/A'))[:180]}")
    print(f"\n  Critic:     {str(result.get('critic_output', 'N/A'))[:180]}")
    print(f"\n  Reflection: {str(result.get('reflection', 'N/A'))[:180]}")
    plan = result.get("plan", "N/A")
    if isinstance(plan, list):
        print("\n  Plan:")
        for step in plan[:3]:
            print(f"    - {step}")
    else:
        print(f"\n  Plan: {str(plan)[:200]}")

    # ── 6. Visualization ─────────────────────────────────────────────
    banner("6 / 6  Visualizations")
    from arca.viz.visualizer import ARCAVisualizer
    viz = ARCAVisualizer(output_dir="arca_outputs/figures")
    env.reset()

    try:
        viz.plot_network(env.get_network_graph(), env.get_hosts(), save=True, show=False)
        print("  ✓ Network topology → arca_outputs/figures/network_topology.html")
    except Exception as e:
        print(f"  ✗ Network plot: {e}")

    try:
        viz.plot_vuln_heatmap(env.get_hosts(), save=True, show=False)
        print("  ✓ Vulnerability heatmap → arca_outputs/figures/vuln_heatmap.html")
    except Exception as e:
        print(f"  ✗ Vuln heatmap: {e}")

    try:
        # Synthetic, seeded demo data purely to exercise plot_training_curves.
        # These numbers are NOT the metrics of the training run above.
        rng = random.Random(0)
        n = 40
        log_data = {
            "episodes": list(range(n)),
            "rewards": [rng.gauss(10 * (1 + i / n), 4) for i in range(n)],
            "compromised": [rng.randint(1, 6) for _ in range(n)],
            "path_lengths": [rng.randint(1, 8) for _ in range(n)],
            "success_rates": [min(1.0, 0.2 + 0.5 * i / n + rng.gauss(0, 0.05)) for i in range(n)],
        }
        viz.plot_training_curves(log_data, save=True, show=False)
        print("  ✓ Training curves  → arca_outputs/figures/training_curves.html")
    except Exception as e:
        print(f"  ✗ Training curves: {e}")

    # ── Save ─────────────────────────────────────────────────────────
    model_path = agent.save()
    print(f"\n  ✓ Model saved → {model_path}")

    # ── Summary ──────────────────────────────────────────────────────
    banner("ARCA Quickstart Complete!")
    print("\n  Next steps:")
    print("    arca train --timesteps 100000    # longer training")
    print("    arca serve                       # REST API at :8000")
    print("    arca audit --preset enterprise   # full audit report")
    print("    arca viz --show                  # open plots in browser")
    print("    arca info                        # system info\n")


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()



# File: tests/test_arca.py
# ================================================
"""
tests/test_arca.py
==================
Comprehensive test suite for ARCA.

Run with:
    pytest tests/ -v
    pytest tests/ -v --tb=short -x   # stop on first failure
"""

from __future__ import annotations

import numpy as np
import pytest


# ──────────────────────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────────────────────

class TestARCAConfig:
    """ARCAConfig defaults, directory creation and YAML round-tripping."""

    def test_default_config(self):
        from arca.core.config import ARCAConfig
        config = ARCAConfig.default()
        assert config.env.num_hosts == 10
        assert config.rl.algorithm == "PPO"
        assert config.verbose in (0, 1)

    def test_config_ensure_dirs(self, tmp_path):
        from arca.core.config import ARCAConfig
        models_dir, logs_dir, figures_dir = (
            tmp_path / name for name in ("models", "logs", "figures")
        )
        config = ARCAConfig.default()
        config.model_dir = str(models_dir)
        config.log_dir = str(logs_dir)
        config.viz.output_dir = str(figures_dir)
        config.ensure_dirs()
        for directory in (models_dir, logs_dir, figures_dir):
            assert directory.exists()

    def test_config_yaml_roundtrip(self, tmp_path):
        from arca.core.config import ARCAConfig
        original = ARCAConfig.default()
        original.env.num_hosts = 15
        original.rl.learning_rate = 1e-4
        target = tmp_path / "config.yaml"
        original.to_yaml(target)
        restored = ARCAConfig.from_yaml(target)
        assert restored.env.num_hosts == 15
        assert abs(restored.rl.learning_rate - 1e-4) < 1e-10


# ──────────────────────────────────────────────────────────────────────────────
# HOST / ACTION
# ──────────────────────────────────────────────────────────────────────────────

class TestHostAndAction:
    """Host and Action/ActionResult value objects and their dict serialisation."""

    def test_host_creation(self):
        from arca.sim.host import Host, HostStatus
        h = Host(id=0, subnet=0, os="Linux", ip="10.0.1.1")
        assert h.status == HostStatus.UNKNOWN
        assert not h.discovered
        d = h.to_dict()
        assert d["id"] == 0
        assert d["os"] == "Linux"

    def test_host_status_transitions(self):
        from arca.sim.host import Host, HostStatus
        h = Host(id=1, subnet=0, os="Windows", ip="10.0.1.2")
        h.discovered = True
        h.status = HostStatus.COMPROMISED
        assert h.status == HostStatus.COMPROMISED

    def test_action_creation(self):
        from arca.sim.action import Action, ActionType
        act = Action(action_type=ActionType.SCAN, target_host=3, exploit_id=0, source_host=0)
        assert act.action_type == ActionType.SCAN
        d = act.to_dict()
        assert d["type"] == "SCAN"

    def test_action_result(self):
        from arca.sim.action import ActionResult
        r = ActionResult(success=True, message="Scan OK", discovered_hosts=[2])
        assert r.success
        assert 2 in r.discovered_hosts
        d = r.to_dict()
        assert d["success"] is True


# ──────────────────────────────────────────────────────────────────────────────
# NETWORK GENERATOR
# ──────────────────────────────────────────────────────────────────────────────

class TestNetworkGenerator:
    """Topology generation: node count, attacker placement, vulnerability seeding."""

    @staticmethod
    def _build(seed, **env_kwargs):
        """Generate a network from EnvConfig(**env_kwargs); return (generator, graph, hosts)."""
        import random
        from arca.sim.network_generator import NetworkGenerator
        from arca.core.config import EnvConfig
        env_cfg = EnvConfig(**env_kwargs)
        generator = NetworkGenerator(env_cfg, rng=random.Random(seed))
        graph, hosts = generator.generate()
        return generator, graph, hosts

    def test_generator_creates_graph(self):
        _, graph, hosts = self._build(42, num_hosts=8, num_subnets=2)
        assert len(hosts) == 8
        assert graph.number_of_nodes() == 8

    def test_attacker_node_in_subnet0(self):
        generator, _, hosts = self._build(0, num_hosts=8, num_subnets=2)
        assert hosts[generator.attacker_node].subnet == 0

    def test_vulnerability_assignment(self):
        _, _, hosts = self._build(7, num_hosts=20, num_subnets=3, vulnerability_density=1.0)
        assert any(len(host.vulnerabilities) > 0 for host in hosts.values())


# ──────────────────────────────────────────────────────────────────────────────
# NETWORK ENVIRONMENT
# ──────────────────────────────────────────────────────────────────────────────

class TestNetworkEnv:
    """Gymnasium contract of NetworkEnv: reset/step shapes, bounds, presets, termination."""

    def test_env_reset(self):
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        obs, info = env.reset()
        assert obs.shape == env.observation_space.shape
        assert "attacker_node" in info

    def test_env_step_returns_valid_obs(self):
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        env.reset()
        # Seed the sampler so any failure is reproducible.
        env.action_space.seed(0)
        action = env.action_space.sample()
        obs2, reward, terminated, truncated, info = env.step(action)
        assert obs2.shape == env.observation_space.shape
        assert isinstance(reward, float)
        assert isinstance(terminated, bool)
        assert isinstance(truncated, bool)

    def test_env_render_returns_string(self):
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        env.reset()
        rendered = env.render()
        assert isinstance(rendered, str)
        assert "ARCA" in rendered or "Host" in rendered

    def test_env_observation_in_bounds(self):
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        obs, _ = env.reset()
        assert np.all(obs >= env.observation_space.low)
        assert np.all(obs <= env.observation_space.high)

    def test_env_presets(self):
        from arca.sim.environment import NetworkEnv
        for preset in ["small_office", "enterprise", "dmz", "iot_network"]:
            env = NetworkEnv.from_preset(preset)
            obs, _ = env.reset()
            assert obs is not None, f"Preset {preset} failed to reset"
            # Every preset must honour its own declared observation space.
            assert obs.shape == env.observation_space.shape, f"Preset {preset}: bad obs shape"

    def test_env_episode_terminates(self):
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        env.reset()
        env.action_space.seed(0)  # deterministic random policy
        max_iters = 500
        for _ in range(max_iters):
            _, _, terminated, truncated, _ = env.step(env.action_space.sample())
            if terminated or truncated:
                break
        else:
            pytest.fail(f"Episode should terminate within max_steps (ran {max_iters} steps)")

    def test_env_get_state_dict(self):
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        env.reset()
        state = env.get_state_dict()
        assert "hosts" in state
        assert "attacker_node" in state
        assert "episode_info" in state

    def test_env_action_space_discrete(self):
        from arca.sim.environment import NetworkEnv
        from gymnasium.spaces import Discrete
        env = NetworkEnv.from_preset("small_office")
        assert isinstance(env.action_space, Discrete)

    def test_env_from_custom_config(self):
        from arca.sim.environment import NetworkEnv
        from arca.core.config import ARCAConfig, EnvConfig
        cfg = ARCAConfig.default()
        cfg.env = EnvConfig(num_hosts=6, num_subnets=2, max_steps=50)
        env = NetworkEnv(cfg=cfg)
        obs, _ = env.reset()
        assert obs.shape[0] == 6 * env._HOST_FEATURES


# ──────────────────────────────────────────────────────────────────────────────
# C++ EXTENSION (both CPP and fallback must work)
# ──────────────────────────────────────────────────────────────────────────────

class TestCppExtension:
    """Contracts of compute_reachability / floyd_warshall / batch_exploit.

    The same assertions must hold whether the compiled extension or the
    pure-Python fallback is active.
    """

    @staticmethod
    def _succeeded(result):
        """Return the success flag of a fallback dict or a C++ ExploitResult."""
        return result["success"] if isinstance(result, dict) else result.success

    def test_import_does_not_crash(self):
        from arca.cpp_ext import CPP_AVAILABLE, compute_reachability, floyd_warshall, batch_exploit  # noqa: F401
        assert isinstance(CPP_AVAILABLE, bool)

    def test_compute_reachability_2node(self):
        from arca.cpp_ext import compute_reachability
        # Edges in both directions: node0 -> node1 and node1 -> node0.
        matrix = compute_reachability([[1], [0]], 2)
        assert matrix[0][1] is True
        assert matrix[1][0] is True

    def test_compute_reachability_disconnected(self):
        from arca.cpp_ext import compute_reachability
        matrix = compute_reachability([[], []], 2)  # no edges at all
        assert matrix[0][0] is True   # every node reaches itself
        assert matrix[0][1] is False  # but nothing else

    def test_floyd_warshall_simple(self):
        from arca.cpp_ext import floyd_warshall
        import math
        inf = math.inf
        weights = [
            [0, 1, inf],
            [inf, 0, 2],
            [inf, inf, 0],
        ]
        # Shortest 0 -> 2 goes via node 1: 1 + 2.
        assert floyd_warshall(weights, 3)[0][2] == 3.0

    def test_batch_exploit_success_rate(self):
        from arca.cpp_ext import batch_exploit
        hosts = [{"exploit_prob": 1.0}] * 5
        results = batch_exploit(hosts, [(idx, 0) for idx in range(5)], seed=42)
        assert len(results) == 5
        for result in results:
            assert self._succeeded(result) is True

    def test_batch_exploit_zero_prob(self):
        from arca.cpp_ext import batch_exploit
        hosts = [{"exploit_prob": 0.0}] * 3
        results = batch_exploit(hosts, [(0, 0), (1, 0), (2, 0)], seed=1)
        for result in results:
            assert self._succeeded(result) is False


# ──────────────────────────────────────────────────────────────────────────────
# AGENT (no training — just instantiation and env interaction)
# ──────────────────────────────────────────────────────────────────────────────

class TestARCAAgent:
    """Agent construction, rule-based reflection, short training runs and persistence."""

    @staticmethod
    def _fast_cfg(tmp_path):
        """Small, quiet config whose model/log/figure dirs all live under tmp_path.

        This keeps anything written during training out of the working directory.
        """
        from arca.core.config import ARCAConfig
        cfg = ARCAConfig.default()
        cfg.rl.n_steps = 64
        cfg.rl.batch_size = 32
        cfg.verbose = 0
        cfg.model_dir = str(tmp_path / "models")
        cfg.log_dir = str(tmp_path / "logs")
        cfg.viz.output_dir = str(tmp_path / "figures")
        cfg.ensure_dirs()
        return cfg

    def test_agent_instantiation(self):
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        agent = ARCAAgent(env=env)
        assert agent is not None
        assert agent.env is env

    def test_agent_repr(self):
        from arca.core.agent import ARCAAgent
        agent = ARCAAgent()
        r = repr(agent)
        assert "ARCAAgent" in r

    def test_agent_reflect_without_llm(self):
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        env.reset()
        agent = ARCAAgent(env=env)
        state = env.get_state_dict()
        # Should not raise even without Ollama running
        result = agent.reflect(state)
        assert isinstance(result, dict)

    def test_agent_train_and_run_episode(self, tmp_path):
        """Quick 1000-step training + 1 episode — validates full pipeline."""
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv
        cfg = self._fast_cfg(tmp_path)
        env = NetworkEnv.from_preset("small_office", cfg=cfg)
        agent = ARCAAgent(env=env, cfg=cfg)
        agent.train(timesteps=1000, progress_bar=False)
        info = agent.run_episode()
        assert info is not None
        assert info.steps > 0

    def test_agent_save_and_load(self, tmp_path):
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv
        cfg = self._fast_cfg(tmp_path)
        env = NetworkEnv.from_preset("small_office", cfg=cfg)
        agent = ARCAAgent(env=env, cfg=cfg)
        agent.train(timesteps=500, progress_bar=False)
        save_path = str(tmp_path / "model")
        agent.save(save_path)
        assert (tmp_path / "model.zip").exists()

        agent2 = ARCAAgent(env=env, cfg=cfg)
        agent2.load(save_path)
        info = agent2.run_episode()
        assert info.steps > 0


# ──────────────────────────────────────────────────────────────────────────────
# VISUALIZER
# ──────────────────────────────────────────────────────────────────────────────

class TestARCAVisualizer:
    """Each plotting entry point, called with save=True, must write output to output_dir."""

    def test_visualizer_instantiation(self, tmp_path):
        from arca.viz.visualizer import ARCAVisualizer
        viz = ARCAVisualizer(output_dir=str(tmp_path))
        assert viz.output_dir.exists()

    def test_plot_network_saves_html(self, tmp_path):
        from arca.viz.visualizer import ARCAVisualizer
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        env.reset()
        viz = ARCAVisualizer(output_dir=str(tmp_path))
        viz.plot_network(env.get_network_graph(), env.get_hosts(), save=True, show=False)
        outputs = list(tmp_path.iterdir())
        assert len(outputs) >= 1

    def test_plot_vuln_heatmap_saves(self, tmp_path):
        from arca.viz.visualizer import ARCAVisualizer
        from arca.sim.environment import NetworkEnv
        env = NetworkEnv.from_preset("small_office")
        env.reset()
        viz = ARCAVisualizer(output_dir=str(tmp_path))
        viz.plot_vuln_heatmap(env.get_hosts(), save=True, show=False)
        assert any(tmp_path.iterdir()), "plot_vuln_heatmap(save=True) wrote nothing"

    def test_plot_training_curves_saves(self, tmp_path):
        from arca.viz.visualizer import ARCAVisualizer
        import random
        rng = random.Random(0)  # deterministic input keeps failures reproducible
        n = 20
        log_data = {
            "episodes": list(range(n)),
            "rewards": [rng.gauss(5, 2) for _ in range(n)],
            "compromised": [rng.randint(1, 4) for _ in range(n)],
            "path_lengths": [rng.randint(1, 6) for _ in range(n)],
            "success_rates": [rng.uniform(0.2, 0.8) for _ in range(n)],
        }
        viz = ARCAVisualizer(output_dir=str(tmp_path))
        viz.plot_training_curves(log_data, save=True, show=False)
        assert any(tmp_path.iterdir()), "plot_training_curves(save=True) wrote nothing"


# ──────────────────────────────────────────────────────────────────────────────
# EPISODE INFO
# ──────────────────────────────────────────────────────────────────────────────

class TestEpisodeInfo:
    """EpisodeInfo.summary() must include the headline episode statistics."""

    def test_summary_string(self):
        from arca.sim.environment import EpisodeInfo
        fields = dict(
            total_reward=42.5,
            steps=100,
            hosts_compromised=3,
            hosts_discovered=5,
            goal_reached=True,
        )
        summary = EpisodeInfo(**fields).summary()
        for expected in ("42.5", "100"):
            assert expected in summary


# ──────────────────────────────────────────────────────────────────────────────
# INTEGRATION: Full pipeline smoke test
# ──────────────────────────────────────────────────────────────────────────────

class TestIntegration:
    """End-to-end smoke test across config, env, agent, reflection and visualisation."""

    def test_full_pipeline_small(self, tmp_path):
        """Smoke test: config → env → agent → train → eval → reflect → viz."""
        from arca.core.config import ARCAConfig
        from arca.core.agent import ARCAAgent
        from arca.sim.environment import NetworkEnv
        from arca.viz.visualizer import ARCAVisualizer

        figures_dir = tmp_path / "figures"
        cfg = ARCAConfig.default()
        cfg.env.preset = "small_office"
        cfg.rl.n_steps = 64
        cfg.rl.batch_size = 32
        cfg.verbose = 0
        cfg.model_dir = str(tmp_path / "models")
        cfg.log_dir = str(tmp_path / "logs")
        cfg.viz.output_dir = str(figures_dir)
        cfg.ensure_dirs()

        env = NetworkEnv.from_preset("small_office", cfg=cfg)
        agent = ARCAAgent(env=env, cfg=cfg)
        agent.train(timesteps=500, progress_bar=False)

        info = agent.run_episode()
        assert info.steps > 0

        state = env.get_state_dict()
        reflection = agent.reflect(state)
        assert isinstance(reflection, dict)

        viz = ARCAVisualizer(output_dir=str(figures_dir))
        env.reset()
        viz.plot_network(env.get_network_graph(), env.get_hosts(), save=True, show=False)
        assert any(figures_dir.iterdir()), "pipeline produced no figure output"

        print(f"\n[Integration] Steps={info.steps}, Compromised={info.hosts_compromised}, Goal={info.goal_reached}")



# ================================================
# End of ARCA codebase context
# Total files included: 21
