wactorz_interfaces/
ws.rs

1//! WebSocket routes for the Wactorz server.
2//!
3//! Two routes are mounted under the same axum `Router`:
4//!
5//! - `/ws`   — Python-compatible aggregated-state bridge.
6//!   Compatible with `monitor.html` (and any client expecting
7//!   `full_snapshot` / `patch` / `delete_agent` JSON messages).
8//!
9//! - `/mqtt` — Transparent WebSocket proxy to the Mosquitto broker's WS
10//!   listener (configurable host/port, default `localhost:9001`).
11//!   Compatible with `mqtt.js` / `frontend/dist/index.html`.
12//!
13//! Together these two routes ensure **any combination** of
14//! `python|rust` backend × `monitor.html|frontend/dist/index.html` frontend
15//! works without any client-side changes.
16//!
17//! ## `/ws` message protocol  (mirrors `monitor_server.py`)
18//!
19//! **Server → browser** on connect:
20//! ```json
21//! { "type": "full_snapshot", "state": { "agents": [...], "nodes": [...], ... } }
22//! ```
23//! **Server → browser** on MQTT event:
24//! ```json
25//! { "type": "patch", "event": { ... }, "state": { ... } }
26//! ```
27//! **Server → browser** after delete command:
28//! ```json
29//! { "type": "delete_agent", "agent_id": "...", "state": { ... } }
30//! ```
31//! **Browser → server** (commands):
32//! ```json
33//! { "type": "command", "command": "pause|stop|resume|delete", "agent_id": "..." }
34//! ```
35
36use std::collections::HashMap;
37use std::sync::Arc;
38use std::time::{SystemTime, UNIX_EPOCH};
39
40use axum::{
41    Router,
42    extract::{
43        State,
44        ws::{Message, WebSocket, WebSocketUpgrade},
45    },
46    http::HeaderMap,
47    response::IntoResponse,
48    routing::get,
49};
50use futures_util::{SinkExt, StreamExt};
51use serde::{Deserialize, Serialize};
52use serde_json::{Value, json};
53use tokio::sync::{Mutex, broadcast, mpsc};
54
55use wactorz_core::{ActorSystem, Message as ActorMessage};
56use wactorz_mqtt::MqttClient;
57
58const AGENT_STALE_SECS: f64 = 90.0;
59const TERMINAL_AGENT_GRACE_SECS: f64 = 15.0;
60
61// ── Internal MQTT envelope (Rust MQTT loop → WS state aggregator) ─────────────
62
63/// Raw MQTT message forwarded from the broker event loop.
64/// Consumed by [`WsBridge::spawn_monitor_task`]; not sent to browser clients.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct WsEnvelope {
67    pub topic: String,
68    pub payload: Value,
69}
70
71// ── Monitor state ─────────────────────────────────────────────────────────────
72
73fn now_secs() -> f64 {
74    SystemTime::now()
75        .duration_since(UNIX_EPOCH)
76        .unwrap_or_default()
77        .as_secs_f64()
78}
79
80/// Mirrors the in-memory state maintained by Python's `monitor_server.py`.
81#[derive(Debug, Default)]
82pub struct MonitorState {
83    agents: HashMap<String, Value>,
84    nodes: HashMap<String, Value>,
85    alerts: Vec<Value>,
86    log_feed: Vec<Value>,
87    system_health: Value,
88}
89
90impl MonitorState {
91    fn prune_stale(&mut self) {
92        let now = now_secs();
93        self.agents.retain(|_, agent| {
94            let last_update = agent
95                .get("last_update")
96                .and_then(|v| v.as_f64())
97                .unwrap_or(now);
98            let state = agent.get("state").and_then(|v| v.as_str()).unwrap_or("");
99            let max_age = match state {
100                "stopped" | "failed" => TERMINAL_AGENT_GRACE_SECS,
101                _ => AGENT_STALE_SECS,
102            };
103            now - last_update <= max_age
104        });
105    }
106
107    /// Serialisable snapshot sent to browser clients.
108    pub fn snapshot(&mut self) -> Value {
109        self.prune_stale();
110        let agents: Vec<Value> = self.agents.values().cloned().collect();
111        let nodes: Vec<Value> = self.nodes.values().cloned().collect();
112        let total_cost: f64 = self
113            .agents
114            .values()
115            .filter_map(|a| a.get("cost_usd").and_then(|v| v.as_f64()))
116            .sum();
117        let alert_end = self.alerts.len().min(10);
118        let log_end = self.log_feed.len().min(20);
119        json!({
120            "agents":          agents,
121            "nodes":           nodes,
122            "alerts":          &self.alerts[..alert_end],
123            "log_feed":        &self.log_feed[..log_end],
124            "system_health":   self.system_health,
125            "total_cost_usd":  (total_cost * 1_000_000.0).round() / 1_000_000.0,
126        })
127    }
128
129    fn update_agent(&mut self, agent_id: &str, key: &str, data: Value) {
130        let short = &agent_id[..agent_id.len().min(8)];
131        let entry = self.agents.entry(agent_id.to_string()).or_insert_with(|| {
132            json!({
133                "agent_id":   agent_id,
134                "name":       short,
135                "first_seen": now_secs(),
136            })
137        });
138        if let Some(obj) = entry.as_object_mut() {
139            obj.insert(key.to_string(), data);
140            obj.insert("last_update".to_string(), json!(now_secs()));
141        }
142    }
143
144    fn add_log(&mut self, entry: Value) {
145        self.log_feed.insert(0, entry);
146        if self.log_feed.len() > 100 {
147            self.log_feed.pop();
148        }
149    }
150
151    /// Parse one MQTT message and update internal state.
152    ///
153    /// Returns `Some((event, is_heartbeat))` when something should be
154    /// broadcast, or `None` when the topic is not recognised.
155    /// `is_heartbeat` suppresses the event from the browser's log feed
156    /// (mirrors Python behaviour).
157    pub fn parse_topic(&mut self, topic: &str, payload: Value) -> Option<(Value, bool)> {
158        let parts: Vec<&str> = topic.split('/').collect();
159
160        // ── system/# ────────────────────────────────────────────────────────
161        if parts[0] == "system" && parts.len() >= 2 {
162            match parts[1] {
163                "health" => {
164                    self.system_health = payload.clone();
165                }
166                "alerts" => {
167                    self.alerts.insert(0, payload.clone());
168                    if self.alerts.len() > 50 {
169                        self.alerts.pop();
170                    }
171                }
172                _ => {}
173            }
174            return Some((
175                json!({
176                    "type":    "system",
177                    "subtype": parts[1],
178                    "data":    payload,
179                }),
180                false,
181            ));
182        }
183
184        // ── agents/{id}/{metric} ─────────────────────────────────────────────
185        if parts[0] == "agents" && parts.len() >= 3 {
186            let agent_id = parts[1];
187            let metric = parts[2];
188
189            match metric {
190                "status" => {
191                    self.update_agent(agent_id, "status", payload.clone());
192                    if let Some(obj) = payload.as_object()
193                        && let Some(entry) = self.agents.get_mut(agent_id)
194                        && let Some(e) = entry.as_object_mut()
195                    {
196                        if let Some(n) = obj.get("name") {
197                            e.insert("name".into(), n.clone());
198                        }
199                        if let Some(s) = obj.get("state") {
200                            e.insert("state".into(), s.clone());
201                        }
202                    }
203                    self.add_log(json!({
204                        "type":      "status",
205                        "agent_id":  agent_id,
206                        "status":    payload,
207                        "timestamp": now_secs(),
208                    }));
209                }
210                "heartbeat" => {
211                    self.update_agent(agent_id, "heartbeat", payload.clone());
212                    if let Some(obj) = payload.as_object() {
213                        let short = &agent_id[..agent_id.len().min(8)];
214                        let name = obj.get("name").and_then(|v| v.as_str()).unwrap_or(short);
215                        if let Some(entry) = self.agents.get_mut(agent_id)
216                            && let Some(e) = entry.as_object_mut()
217                        {
218                            e.insert("name".into(), json!(name));
219                            for k in &["cpu", "state"] {
220                                if let Some(v) = obj.get(*k) {
221                                    e.insert(k.to_string(), v.clone());
222                                }
223                            }
224                            if let Some(v) = obj.get("memory_mb") {
225                                e.insert("mem".into(), v.clone());
226                            }
227                            if let Some(v) = obj.get("task") {
228                                e.insert("task".into(), v.clone());
229                            }
230                        }
231                    }
232                    // heartbeat → broadcast state update but suppress from log_feed
233                    return Some((
234                        json!({
235                            "type":     "agent",
236                            "agent_id": agent_id,
237                            "metric":   metric,
238                            "data":     payload,
239                        }),
240                        true,
241                    ));
242                }
243                "metrics" => {
244                    self.update_agent(agent_id, "metrics", payload.clone());
245                    if let Some(obj) = payload.as_object()
246                        && let Some(entry) = self.agents.get_mut(agent_id)
247                        && let Some(e) = entry.as_object_mut()
248                    {
249                        for k in &[
250                            "messages_processed",
251                            "cost_usd",
252                            "input_tokens",
253                            "output_tokens",
254                        ] {
255                            if let Some(v) = obj.get(*k) {
256                                e.insert(k.to_string(), v.clone());
257                            }
258                        }
259                    }
260                }
261                "logs" => {
262                    let mut log = json!({
263                        "type":      "log",
264                        "agent_id":  agent_id,
265                        "timestamp": now_secs(),
266                    });
267                    if let (Some(src), Some(dst)) = (payload.as_object(), log.as_object_mut()) {
268                        for (k, v) in src {
269                            dst.entry(k.clone()).or_insert(v.clone());
270                        }
271                    }
272                    self.add_log(log);
273                }
274                "spawned" => {
275                    let mut log = json!({
276                        "type":      "spawned",
277                        "agent_id":  agent_id,
278                        "timestamp": now_secs(),
279                    });
280                    if let (Some(src), Some(dst)) = (payload.as_object(), log.as_object_mut()) {
281                        for (k, v) in src {
282                            dst.entry(k.clone()).or_insert(v.clone());
283                        }
284                    }
285                    self.add_log(log);
286                }
287                "completed" => {
288                    self.update_agent(agent_id, "last_completed", payload.clone());
289                    self.add_log(json!({
290                        "type":      "completed",
291                        "agent_id":  agent_id,
292                        "timestamp": now_secs(),
293                    }));
294                }
295                "alert" => {
296                    let short = &agent_id[..agent_id.len().min(8)];
297                    let known_name = self
298                        .agents
299                        .get(agent_id)
300                        .and_then(|a| a.get("name"))
301                        .and_then(|v| v.as_str())
302                        .unwrap_or(short)
303                        .to_string();
304                    let enriched = if let Some(obj) = payload.as_object() {
305                        let mut e = obj.clone();
306                        e.insert("agent_id".into(), json!(agent_id));
307                        e.entry("name".to_string())
308                            .or_insert_with(|| json!(&known_name));
309                        Value::Object(e)
310                    } else {
311                        json!({ "agent_id": agent_id })
312                    };
313                    let severity = enriched
314                        .get("severity")
315                        .and_then(|v| v.as_str())
316                        .unwrap_or("warning")
317                        .to_string();
318                    let name = enriched
319                        .get("name")
320                        .and_then(|v| v.as_str())
321                        .unwrap_or(&known_name)
322                        .to_string();
323                    self.alerts.insert(0, enriched);
324                    if self.alerts.len() > 50 {
325                        self.alerts.pop();
326                    }
327                    self.add_log(json!({
328                        "type":      "alert",
329                        "agent_id":  agent_id,
330                        "name":      name,
331                        "message":   format!("{name} unresponsive ({severity})"),
332                        "timestamp": now_secs(),
333                    }));
334                }
335                _ => {}
336            }
337            return Some((
338                json!({
339                    "type":     "agent",
340                    "agent_id": agent_id,
341                    "metric":   metric,
342                    "data":     payload,
343                }),
344                false,
345            ));
346        }
347
348        // ── nodes/{name}/heartbeat ───────────────────────────────────────────
349        if parts[0] == "nodes" && parts.len() >= 3 && parts[2] == "heartbeat" {
350            let node_name = parts[1];
351            if let Some(obj) = payload.as_object() {
352                self.nodes.insert(
353                    node_name.to_string(),
354                    json!({
355                        "node":      node_name,
356                        "agents":    obj.get("agents").cloned().unwrap_or(json!([])),
357                        "last_seen": now_secs(),
358                        "online":    true,
359                        "node_id":   obj.get("node_id").cloned().unwrap_or(json!("")),
360                    }),
361                );
362            }
363            return Some((
364                json!({
365                    "type":      "node",
366                    "node_name": node_name,
367                    "data":      payload,
368                }),
369                false,
370            ));
371        }
372
373        None
374    }
375}
376
377// ── Shared bridge state ───────────────────────────────────────────────────────
378
379#[derive(Clone)]
380pub struct BridgeState {
381    /// MQTT → WS broadcast (raw envelopes, consumed by monitor task).
382    pub mqtt_tx: broadcast::Sender<WsEnvelope>,
383    /// Aggregated monitor state shared across all `/ws` connections.
384    pub monitor: Arc<Mutex<MonitorState>>,
385    /// Broadcast channel: serialised JSON patches to all `/ws` clients.
386    pub monitor_tx: broadcast::Sender<String>,
387    /// MQTT client for publishing commands received from the browser.
388    pub mqtt_client: Arc<MqttClient>,
389    /// Live actor registry used for direct browser -> actor chat routing.
390    pub system: ActorSystem,
391    /// Mosquitto WebSocket host (for `/mqtt` proxy).
392    pub mqtt_host: String,
393    /// Mosquitto WebSocket port (for `/mqtt` proxy, default 9001).
394    pub mqtt_ws_port: u16,
395}
396
397// ── WsBridge ──────────────────────────────────────────────────────────────────
398
399pub struct WsBridge {
400    state: BridgeState,
401}
402
403impl WsBridge {
404    pub fn new(
405        mqtt_tx: broadcast::Sender<WsEnvelope>,
406        mqtt_client: Arc<MqttClient>,
407        system: ActorSystem,
408        mqtt_host: String,
409        mqtt_ws_port: u16,
410    ) -> Self {
411        let (monitor_tx, _) = broadcast::channel::<String>(256);
412        Self {
413            state: BridgeState {
414                mqtt_tx,
415                monitor: Arc::new(Mutex::new(MonitorState::default())),
416                monitor_tx,
417                mqtt_client,
418                system,
419                mqtt_host,
420                mqtt_ws_port,
421            },
422        }
423    }
424
425    /// Spawn a background task that:
426    ///
427    /// 1. Subscribes to `nodes/#` so remote-node heartbeats reach the bridge.
428    /// 2. Consumes raw MQTT envelopes from the broadcast channel.
429    /// 3. Updates [`MonitorState`].
430    /// 4. Broadcasts Python-compatible JSON patches to all `/ws` clients.
431    pub fn spawn_monitor_task(&self) {
432        // Subscribe to nodes/# so remote node heartbeats are received.
433        // agents/# and system/# are subscribed in main.rs; nodes/# is the
434        // bridge's own concern.
435        let mqtt_for_sub = Arc::clone(&self.state.mqtt_client);
436        tokio::spawn(async move {
437            if let Err(e) = mqtt_for_sub.subscribe("nodes/#").await {
438                tracing::warn!(
439                    "[ws-bridge] nodes/# subscribe failed (broker may not be running): {e}"
440                );
441            } else {
442                tracing::info!("[ws-bridge] subscribed to nodes/#");
443            }
444        });
445
446        let mut rx = self.state.mqtt_tx.subscribe();
447        let monitor = Arc::clone(&self.state.monitor);
448        let monitor_tx = self.state.monitor_tx.clone();
449
450        tokio::spawn(async move {
451            while let Ok(envelope) = rx.recv().await {
452                let msg = {
453                    let mut st = monitor.lock().await;
454                    match st.parse_topic(&envelope.topic, envelope.payload) {
455                        None => continue,
456                        Some((event, is_heartbeat)) => {
457                            let snap = st.snapshot();
458                            let log_event = if is_heartbeat { Value::Null } else { event };
459                            serde_json::to_string(&json!({
460                                "type":  "patch",
461                                "event": log_event,
462                                "state": snap,
463                            }))
464                            .unwrap_or_default()
465                        }
466                    }
467                };
468                if !msg.is_empty() {
469                    let _ = monitor_tx.send(msg);
470                }
471            }
472        });
473    }
474
475    /// Build the axum `Router` with `/ws` and `/mqtt` routes.
476    pub fn router(&self) -> Router {
477        Router::new()
478            .route("/ws", get(ws_handler))
479            .route("/mqtt", get(mqtt_proxy_handler))
480            .with_state(self.state.clone())
481    }
482}
483
484// ── /ws handler: Python-compatible aggregated state ───────────────────────────
485
486async fn ws_handler(ws: WebSocketUpgrade, State(state): State<BridgeState>) -> impl IntoResponse {
487    ws.on_upgrade(move |socket| handle_ws_socket(socket, state))
488}
489
490async fn handle_ws_socket(socket: WebSocket, state: BridgeState) {
491    let mut monitor_rx = state.monitor_tx.subscribe();
492    let (mut ws_send, mut ws_recv) = socket.split();
493
494    // Send a full state snapshot immediately on connect (mirrors Python behaviour)
495    let snap_json = {
496        let mut st = state.monitor.lock().await;
497        serde_json::to_string(&json!({
498            "type":  "full_snapshot",
499            "state": st.snapshot(),
500        }))
501        .unwrap_or_default()
502    };
503    if ws_send.send(Message::Text(snap_json.into())).await.is_err() {
504        return;
505    }
506    let config_json = serde_json::to_string(&json!({
507        "type": "config",
508        "chat_mode": "direct_ws",
509    }))
510    .unwrap_or_default();
511    if ws_send
512        .send(Message::Text(config_json.into()))
513        .await
514        .is_err()
515    {
516        return;
517    }
518
519    // Per-client direct-reply channel: slash command responses bypass the
520    // broadcast and go only to this specific connection.
521    let (reply_tx, mut reply_rx) = mpsc::channel::<String>(32);
522
523    // Send task: merges broadcast patches and per-client direct replies.
524    let send_task = tokio::spawn(async move {
525        loop {
526            tokio::select! {
527                Ok(json) = monitor_rx.recv() => {
528                    if ws_send.send(Message::Text(json.into())).await.is_err() {
529                        break;
530                    }
531                }
532                Some(json) = reply_rx.recv() => {
533                    if ws_send.send(Message::Text(json.into())).await.is_err() {
534                        break;
535                    }
536                }
537                else => break,
538            }
539        }
540    });
541
542    // Handle inbound messages (commands and slash commands from the browser)
543    while let Some(Ok(msg)) = ws_recv.next().await {
544        match msg {
545            Message::Close(_) => break,
546            Message::Text(text) => {
547                let trimmed = text.trim();
548                if trimmed.starts_with('/') {
549                    // Slash command — reply only to this client
550                    let reply = handle_slash_command(trimmed, &state).await;
551                    let _ = reply_tx.send(reply).await;
552                } else {
553                    handle_browser_message(trimmed, &state).await;
554                }
555            }
556            _ => {}
557        }
558    }
559    send_task.abort();
560}
561
562/// Handle a slash command sent by a browser client over `/ws`.
563///
564/// Mirrors the Python `handle_slash` dispatcher in `monitor_server.py`.
565/// Returns a JSON string to send back to that specific client only.
566async fn handle_slash_command(text: &str, state: &BridgeState) -> String {
567    let parts: Vec<&str> = text.split_whitespace().collect();
568    let cmd = parts.first().map(|s| s.to_lowercase()).unwrap_or_default();
569
570    let content = match cmd.as_str() {
571        "/help" | "/h" => "Commands:\n\
572             \x20 /agents                        list all active agents\n\
573             \x20 /nodes                         list remote nodes\n\
574             \x20 /help                          show this help\n\n\
575             Everything else is forwarded to the main orchestrator."
576            .to_string(),
577
578        "/agents" => {
579            let st = state.monitor.lock().await;
580            if st.agents.is_empty() {
581                "No agents running.".to_string()
582            } else {
583                let mut lines = vec!["Agents:".to_string()];
584                let mut names: Vec<&str> = st
585                    .agents
586                    .values()
587                    .filter_map(|a| a.get("name").and_then(|v| v.as_str()))
588                    .collect();
589                names.sort_unstable();
590                for name in names {
591                    // Find the full agent entry for this name
592                    let entry = st
593                        .agents
594                        .values()
595                        .find(|a| a.get("name").and_then(|v| v.as_str()) == Some(name));
596                    let state_str = entry
597                        .and_then(|a| a.get("state"))
598                        .and_then(|v| v.as_str())
599                        .unwrap_or("?");
600                    let agent_id = entry
601                        .and_then(|a| a.get("agent_id"))
602                        .and_then(|v| v.as_str())
603                        .unwrap_or("");
604                    let id_short = &agent_id[..agent_id.len().min(8)];
605                    lines.push(format!("  [{state_str:8}] @{name:<22} {id_short}"));
606                }
607                lines.join("\n")
608            }
609        }
610
611        "/nodes" => {
612            let st = state.monitor.lock().await;
613            let mut lines = vec!["Nodes:".to_string()];
614            if st.nodes.is_empty() {
615                lines.push("  (no remote nodes)".to_string());
616            } else {
617                let mut node_names: Vec<&str> = st.nodes.keys().map(|s| s.as_str()).collect();
618                node_names.sort_unstable();
619                for node_name in node_names {
620                    if let Some(nd) = st.nodes.get(node_name) {
621                        let online = nd.get("online").and_then(|v| v.as_bool()).unwrap_or(false);
622                        let status = if online { "online" } else { "OFFLINE" };
623                        let agents: Vec<String> = nd
624                            .get("agents")
625                            .and_then(|v| v.as_array())
626                            .map(|arr| {
627                                arr.iter()
628                                    .filter_map(|v| v.as_str())
629                                    .map(|s| format!("@{s}"))
630                                    .collect()
631                            })
632                            .unwrap_or_default();
633                        let agent_list = if agents.is_empty() {
634                            "(no agents)".to_string()
635                        } else {
636                            agents.join(", ")
637                        };
638                        lines.push(format!("  {node_name:<20} {status:<6}   {agent_list}"));
639                    }
640                }
641            }
642            lines.join("\n")
643        }
644
645        _ => format!("Unknown command: {cmd}. Type /help for available commands."),
646    };
647
648    serde_json::to_string(&json!({
649        "type":      "chat",
650        "from":      "monitor",
651        "content":   content,
652        "timestamp": now_secs(),
653    }))
654    .unwrap_or_default()
655}
656
657async fn handle_browser_message(text: &str, state: &BridgeState) {
658    let Ok(cmd) = serde_json::from_str::<Value>(text) else {
659        return;
660    };
661    match cmd.get("type").and_then(|v| v.as_str()) {
662        Some("command") => handle_browser_command(cmd, state).await,
663        Some("chat") => handle_browser_chat(cmd, state).await,
664        _ => {}
665    }
666}
667
668async fn handle_browser_chat(cmd: Value, state: &BridgeState) {
669    let Some(content) = cmd.get("content").and_then(|v| v.as_str()) else {
670        return;
671    };
672    let target_name = cmd
673        .get("agent_name")
674        .and_then(|v| v.as_str())
675        .filter(|s| !s.is_empty())
676        .unwrap_or("main-actor");
677
678    let Some(entry) = state.system.registry.get_by_name(target_name).await else {
679        tracing::warn!("[ws] chat target not found: {target_name}");
680        return;
681    };
682    let msg = ActorMessage::text(
683        Some("user".to_string()),
684        Some(entry.id.clone()),
685        content.to_string(),
686    );
687    if let Err(err) = state.system.registry.send(&entry.id, msg).await {
688        tracing::warn!("[ws] chat route failed for {target_name}: {err}");
689    }
690}
691
692async fn handle_browser_command(cmd: Value, state: &BridgeState) {
693    let Some(command) = cmd.get("command").and_then(|v| v.as_str()) else {
694        return;
695    };
696    let Some(agent_id) = cmd.get("agent_id").and_then(|v| v.as_str()) else {
697        return;
698    };
699
700    let valid = ["pause", "stop", "resume", "delete"];
701    if !valid.contains(&command) {
702        tracing::warn!("[ws] Unknown command: {command}");
703        return;
704    }
705
706    tracing::info!(
707        "[ws] {} -> {}",
708        command.to_uppercase(),
709        &agent_id[..agent_id.len().min(8)]
710    );
711
712    // Publish command to MQTT
713    let mqtt_payload = json!({
714        "command":   command,
715        "sender":    "monitor-dashboard",
716        "timestamp": now_secs(),
717    });
718    let topic = format!("agents/{agent_id}/commands");
719    if let Err(e) = state.mqtt_client.publish_json(&topic, &mqtt_payload).await {
720        tracing::error!("[ws] MQTT publish failed: {e}");
721        return;
722    }
723
724    // Optimistic state update + broadcast
725    let msg = {
726        let mut st = state.monitor.lock().await;
727        if command == "delete" {
728            st.agents.remove(agent_id);
729            let snap = st.snapshot();
730            serde_json::to_string(&json!({
731                "type":     "delete_agent",
732                "agent_id": agent_id,
733                "state":    snap,
734            }))
735            .unwrap_or_default()
736        } else {
737            if let Some(entry) = st.agents.get_mut(agent_id)
738                && let Some(e) = entry.as_object_mut()
739            {
740                let new_state = match command {
741                    "stop" => "stopped",
742                    "pause" => "paused",
743                    "resume" => "running",
744                    _ => return,
745                };
746                e.insert("state".into(), json!(new_state));
747            }
748            let snap = st.snapshot();
749            serde_json::to_string(&json!({
750                "type":  "patch",
751                "state": snap,
752            }))
753            .unwrap_or_default()
754        }
755    };
756
757    if !msg.is_empty() {
758        let _ = state.monitor_tx.send(msg);
759    }
760}
761
762// ── /mqtt handler: transparent proxy to Mosquitto WS ─────────────────────────
763//
764// The browser's mqtt.js speaks the MQTT binary protocol over WebSocket.
765// We forward every frame verbatim to/from Mosquitto's WS listener (port 9001
766// by default, or whatever --mqtt-ws-port is set to).
767//
768// Supports the "mqtt" subprotocol header so mqtt.js is satisfied.
769
770async fn mqtt_proxy_handler(
771    ws: WebSocketUpgrade,
772    headers: HeaderMap,
773    State(state): State<BridgeState>,
774) -> impl IntoResponse {
775    // Echo back whichever MQTT sub-protocol the client announced
776    let proto = headers
777        .get("sec-websocket-protocol")
778        .and_then(|v| v.to_str().ok())
779        .map(|s| s.to_string());
780
781    let ws = ws.protocols(["mqtt", "mqttv3.1"]);
782    ws.on_upgrade(move |socket| proxy_to_mosquitto(socket, state, proto))
783}
784
785async fn proxy_to_mosquitto(socket: WebSocket, state: BridgeState, proto: Option<String>) {
786    use tokio_tungstenite::connect_async;
787    use tokio_tungstenite::tungstenite::Message as TMsg;
788    use tokio_tungstenite::tungstenite::client::IntoClientRequest;
789
790    let upstream_url = format!("ws://{}:{}/", state.mqtt_host, state.mqtt_ws_port);
791
792    // Build a proper client handshake request, then add the MQTT sub-protocol.
793    let request = {
794        let mut request = match upstream_url.as_str().into_client_request() {
795            Ok(r) => r,
796            Err(e) => {
797                tracing::warn!("[mqtt-proxy] bad upstream request: {e}");
798                return;
799            }
800        };
801        let p = proto.as_deref().unwrap_or("mqtt");
802        if let Ok(value) = p.parse() {
803            request
804                .headers_mut()
805                .insert("Sec-WebSocket-Protocol", value);
806        }
807        request
808    };
809
810    let upstream = match connect_async(request).await {
811        Ok((stream, _)) => stream,
812        Err(e) => {
813            tracing::warn!(
814                "[mqtt-proxy] upstream connect failed ({}): {e}",
815                upstream_url
816            );
817            return;
818        }
819    };
820
821    let (mut up_send, mut up_recv) = upstream.split();
822    let (mut cl_send, mut cl_recv) = socket.split();
823
824    // upstream → client
825    let up_to_cl = tokio::spawn(async move {
826        while let Some(Ok(msg)) = up_recv.next().await {
827            let out = match msg {
828                TMsg::Binary(b) => Message::Binary(b),
829                TMsg::Text(t) => Message::Text(t.as_str().into()),
830                TMsg::Close(_) => break,
831                _ => continue,
832            };
833            if cl_send.send(out).await.is_err() {
834                break;
835            }
836        }
837    });
838
839    // client → upstream
840    while let Some(Ok(msg)) = cl_recv.next().await {
841        let fwd = match msg {
842            Message::Binary(b) => TMsg::Binary(b),
843            Message::Text(t) => TMsg::Text(t.as_str().into()),
844            Message::Close(_) => break,
845            _ => continue,
846        };
847        if up_send.send(fwd).await.is_err() {
848            break;
849        }
850    }
851
852    up_to_cl.abort();
853}