wactorz_agents/
home_assistant_agent.rs

1//! Home Assistant integration agent.
2//!
3//! [`HomeAssistantAgent`] connects to a local Home Assistant instance via
4//! its REST API and WebSocket event bus.  It can query entity states, call
5//! services, and subscribe to state-change events.
6//!
7//! Configuration is read from environment variables:
8//! - `HA_URL`   — Home Assistant base URL (e.g. `http://homeassistant.local:8123`)
9//! - `HA_TOKEN` — Long-lived access token
10
11use anyhow::Result;
12use async_trait::async_trait;
13use std::sync::Arc;
14use tokio::sync::mpsc;
15
16use crate::llm_agent::{LlmAgent, LlmConfig};
17use wactorz_core::{Actor, ActorConfig, ActorMetrics, ActorState, EventPublisher, Message};
18
19/// Home Assistant agent.
20pub struct HomeAssistantAgent {
21    config: ActorConfig,
22    ha_url: String,
23    ha_token: String,
24    http: reqwest::Client,
25    llm: Option<LlmAgent>,
26    state: ActorState,
27    metrics: Arc<ActorMetrics>,
28    mailbox_tx: mpsc::Sender<Message>,
29    mailbox_rx: Option<mpsc::Receiver<Message>>,
30    publisher: Option<EventPublisher>,
31}
32
33impl HomeAssistantAgent {
34    pub fn new(config: ActorConfig) -> Self {
35        let ha_url = std::env::var("HA_URL").unwrap_or_default();
36        let ha_token = std::env::var("HA_TOKEN").unwrap_or_default();
37        let (tx, rx) = mpsc::channel(config.mailbox_capacity);
38        Self {
39            config,
40            ha_url,
41            ha_token,
42            http: reqwest::Client::new(),
43            llm: None,
44            state: ActorState::Initializing,
45            metrics: Arc::new(ActorMetrics::new()),
46            mailbox_tx: tx,
47            mailbox_rx: Some(rx),
48            publisher: None,
49        }
50    }
51
52    pub fn with_publisher(mut self, p: EventPublisher) -> Self {
53        self.publisher = Some(p);
54        self
55    }
56
57    /// Override the HA URL and token instead of relying on environment variables.
58    pub fn with_ha_config(mut self, url: impl Into<String>, token: impl Into<String>) -> Self {
59        let url = url.into();
60        let token = token.into();
61        if !url.is_empty() {
62            self.ha_url = url;
63        }
64        if !token.is_empty() {
65            self.ha_token = token;
66        }
67        self
68    }
69
70    pub fn with_llm(mut self, llm_config: LlmConfig) -> Self {
71        let llm_cfg = ActorConfig::new(format!("{}-llm", self.config.name));
72        self.llm = Some(LlmAgent::new(llm_cfg, llm_config));
73        self
74    }
75
76    /// GET /api/states — return all entity states as JSON.
77    async fn get_states(&self) -> Result<serde_json::Value> {
78        let resp = self
79            .http
80            .get(format!("{}/api/states", self.ha_url))
81            .header("Authorization", format!("Bearer {}", self.ha_token))
82            .header("Content-Type", "application/json")
83            .send()
84            .await?;
85        Ok(resp.json().await?)
86    }
87
88    /// GET /api/states/<entity_id> — single entity state.
89    async fn get_state(&self, entity_id: &str) -> Result<serde_json::Value> {
90        let resp = self
91            .http
92            .get(format!("{}/api/states/{}", self.ha_url, entity_id))
93            .header("Authorization", format!("Bearer {}", self.ha_token))
94            .send()
95            .await?;
96        Ok(resp.json().await?)
97    }
98
99    /// POST /api/services/<domain>/<service> — call a HA service.
100    #[expect(dead_code)]
101    async fn call_service(
102        &self,
103        domain: &str,
104        service: &str,
105        data: serde_json::Value,
106    ) -> Result<serde_json::Value> {
107        let resp = self
108            .http
109            .post(format!(
110                "{}/api/services/{}/{}",
111                self.ha_url, domain, service
112            ))
113            .header("Authorization", format!("Bearer {}", self.ha_token))
114            .header("Content-Type", "application/json")
115            .json(&data)
116            .send()
117            .await?;
118        Ok(resp.json().await?)
119    }
120
121    async fn process_request(&mut self, text: &str) -> String {
122        // Simple keyword dispatch; LLM interprets if available
123        let lower = text.to_lowercase();
124
125        if lower.contains("states") || lower.contains("all entities") {
126            match self.get_states().await {
127                Ok(v) => format!(
128                    "HA states: {}",
129                    serde_json::to_string_pretty(&v).unwrap_or_else(|_| v.to_string())
130                ),
131                Err(e) => format!("HA error: {e}"),
132            }
133        } else if let Some(entity) = extract_entity_id(text) {
134            match self.get_state(&entity).await {
135                Ok(v) => format!("{entity}: {}", v["state"].as_str().unwrap_or("unknown")),
136                Err(e) => format!("HA error: {e}"),
137            }
138        } else if let Some(llm) = &mut self.llm {
139            let prompt = format!(
140                "You are a Home Assistant expert. The user said: \"{text}\"\n\
141                 Interpret this as a HA request and respond helpfully. \
142                 If you need to call a service, suggest: call_service(domain, service, {{data}})."
143            );
144            llm.complete(&prompt)
145                .await
146                .unwrap_or_else(|e| format!("LLM error: {e}"))
147        } else {
148            "I can query HA entity states. Try: 'get state light.living_room' or 'list all states'"
149                .into()
150        }
151    }
152
153    fn now_ms() -> u64 {
154        std::time::SystemTime::now()
155            .duration_since(std::time::UNIX_EPOCH)
156            .unwrap_or_default()
157            .as_millis() as u64
158    }
159}
160
/// Finds the first token in `text` that looks like a Home Assistant entity ID
/// (`domain.object_id`, e.g. `light.living_room` or `sensor.temperature`).
///
/// Punctuation around a token is ignored, so `"is light.kitchen on?"` yields
/// `light.kitchen`. A token qualifies only when it has exactly one dot, both
/// halves are non-empty and consist of ASCII letters, digits or underscores,
/// and the domain starts with a letter. That keeps sentence-ending words
/// (`"lights."`), decimal numbers (`"21.5"`) and URLs from being mistaken for
/// entity IDs. The match is returned exactly as typed.
///
/// Returns `None` when no token qualifies.
fn extract_entity_id(text: &str) -> Option<String> {
    fn is_id_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    fn is_id_part(part: &str) -> bool {
        !part.is_empty() && part.chars().all(is_id_char)
    }

    text.split_whitespace().find_map(|word| {
        // Drop quotes, commas, question marks, trailing full stops, etc.
        let candidate = word.trim_matches(|c: char| !is_id_char(c));
        let (domain, object_id) = candidate.split_once('.')?;
        // A second dot, a colon or a slash ends up in one of the halves and
        // fails `is_id_part`, which is what rejects URLs and file paths.
        let looks_like_entity = domain.starts_with(|c: char| c.is_ascii_alphabetic())
            && is_id_part(domain)
            && is_id_part(object_id);
        looks_like_entity.then(|| candidate.to_string())
    })
}
169
170#[async_trait]
171impl Actor for HomeAssistantAgent {
172    fn id(&self) -> String {
173        self.config.id.clone()
174    }
175    fn name(&self) -> &str {
176        &self.config.name
177    }
178    fn state(&self) -> ActorState {
179        self.state.clone()
180    }
181    fn metrics(&self) -> Arc<ActorMetrics> {
182        Arc::clone(&self.metrics)
183    }
184    fn mailbox(&self) -> mpsc::Sender<Message> {
185        self.mailbox_tx.clone()
186    }
187
188    async fn on_start(&mut self) -> Result<()> {
189        self.state = ActorState::Running;
190        let connected = !self.ha_url.is_empty() && !self.ha_token.is_empty();
191        tracing::info!(
192            "[{}] HA agent started (connected={})",
193            self.config.name,
194            connected
195        );
196        if let Some(pub_) = &self.publisher {
197            pub_.publish(
198                wactorz_mqtt::topics::spawn(&self.config.id),
199                &serde_json::json!({
200                    "agentId":   self.config.id,
201                    "agentName": self.config.name,
202                    "agentType": "home_assistant",
203                    "haConnected": connected,
204                    "timestampMs": Self::now_ms(),
205                }),
206            );
207        }
208        Ok(())
209    }
210
211    async fn handle_message(&mut self, message: Message) -> Result<()> {
212        use wactorz_core::message::MessageType;
213        let text = match &message.payload {
214            MessageType::Text { content } => content.clone(),
215            MessageType::Task { description, .. } => description.clone(),
216            _ => return Ok(()),
217        };
218        let response = self.process_request(&text).await;
219        if let Some(pub_) = &self.publisher {
220            pub_.publish(
221                wactorz_mqtt::topics::chat(&self.config.id),
222                &serde_json::json!({
223                    "from":        self.config.name,
224                    "to":          message.from.as_deref().unwrap_or("user"),
225                    "content":     response,
226                    "timestampMs": Self::now_ms(),
227                }),
228            );
229        }
230        Ok(())
231    }
232
233    async fn on_heartbeat(&mut self) -> Result<()> {
234        if let Some(pub_) = &self.publisher {
235            pub_.publish(
236                wactorz_mqtt::topics::heartbeat(&self.config.id),
237                &serde_json::json!({
238                    "agentId":   self.config.id,
239                    "agentName": self.config.name,
240                    "state":     self.state,
241                    "haUrl":     self.ha_url,
242                    "timestampMs": Self::now_ms(),
243                }),
244            );
245        }
246        Ok(())
247    }
248
249    async fn run(&mut self) -> Result<()> {
250        self.on_start().await?;
251        let mut rx = self
252            .mailbox_rx
253            .take()
254            .ok_or_else(|| anyhow::anyhow!("HomeAssistantAgent already running"))?;
255        let mut hb = tokio::time::interval(std::time::Duration::from_secs(
256            self.config.heartbeat_interval_secs,
257        ));
258        hb.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
259        loop {
260            tokio::select! {
261                biased;
262                msg = rx.recv() => {
263                    match msg {
264                        None => break,
265                        Some(m) => {
266                            self.metrics.record_received();
267                            if let wactorz_core::message::MessageType::Command {
268                                command: wactorz_core::message::ActorCommand::Stop
269                            } = &m.payload { break; }
270                            match self.handle_message(m).await {
271                                Ok(_)  => self.metrics.record_processed(),
272                                Err(e) => { tracing::error!("[{}] {e}", self.config.name); self.metrics.record_failed(); }
273                            }
274                        }
275                    }
276                }
277                _ = hb.tick() => {
278                    self.metrics.record_heartbeat();
279                    if let Err(e) = self.on_heartbeat().await {
280                        tracing::error!("[{}] heartbeat: {e}", self.config.name);
281                    }
282                }
283            }
284        }
285        self.state = ActorState::Stopped;
286        self.on_stop().await
287    }
288}