1use anyhow::Result;
12use async_trait::async_trait;
13use std::sync::Arc;
14use tokio::sync::mpsc;
15
16use crate::llm_agent::{LlmAgent, LlmConfig};
17use wactorz_core::{Actor, ActorConfig, ActorMetrics, ActorState, EventPublisher, Message};
18
19pub struct HomeAssistantAgent {
21 config: ActorConfig,
22 ha_url: String,
23 ha_token: String,
24 http: reqwest::Client,
25 llm: Option<LlmAgent>,
26 state: ActorState,
27 metrics: Arc<ActorMetrics>,
28 mailbox_tx: mpsc::Sender<Message>,
29 mailbox_rx: Option<mpsc::Receiver<Message>>,
30 publisher: Option<EventPublisher>,
31}
32
33impl HomeAssistantAgent {
34 pub fn new(config: ActorConfig) -> Self {
35 let ha_url = std::env::var("HA_URL").unwrap_or_default();
36 let ha_token = std::env::var("HA_TOKEN").unwrap_or_default();
37 let (tx, rx) = mpsc::channel(config.mailbox_capacity);
38 Self {
39 config,
40 ha_url,
41 ha_token,
42 http: reqwest::Client::new(),
43 llm: None,
44 state: ActorState::Initializing,
45 metrics: Arc::new(ActorMetrics::new()),
46 mailbox_tx: tx,
47 mailbox_rx: Some(rx),
48 publisher: None,
49 }
50 }
51
52 pub fn with_publisher(mut self, p: EventPublisher) -> Self {
53 self.publisher = Some(p);
54 self
55 }
56
57 pub fn with_ha_config(mut self, url: impl Into<String>, token: impl Into<String>) -> Self {
59 let url = url.into();
60 let token = token.into();
61 if !url.is_empty() {
62 self.ha_url = url;
63 }
64 if !token.is_empty() {
65 self.ha_token = token;
66 }
67 self
68 }
69
70 pub fn with_llm(mut self, llm_config: LlmConfig) -> Self {
71 let llm_cfg = ActorConfig::new(format!("{}-llm", self.config.name));
72 self.llm = Some(LlmAgent::new(llm_cfg, llm_config));
73 self
74 }
75
76 async fn get_states(&self) -> Result<serde_json::Value> {
78 let resp = self
79 .http
80 .get(format!("{}/api/states", self.ha_url))
81 .header("Authorization", format!("Bearer {}", self.ha_token))
82 .header("Content-Type", "application/json")
83 .send()
84 .await?;
85 Ok(resp.json().await?)
86 }
87
88 async fn get_state(&self, entity_id: &str) -> Result<serde_json::Value> {
90 let resp = self
91 .http
92 .get(format!("{}/api/states/{}", self.ha_url, entity_id))
93 .header("Authorization", format!("Bearer {}", self.ha_token))
94 .send()
95 .await?;
96 Ok(resp.json().await?)
97 }
98
99 #[expect(dead_code)]
101 async fn call_service(
102 &self,
103 domain: &str,
104 service: &str,
105 data: serde_json::Value,
106 ) -> Result<serde_json::Value> {
107 let resp = self
108 .http
109 .post(format!(
110 "{}/api/services/{}/{}",
111 self.ha_url, domain, service
112 ))
113 .header("Authorization", format!("Bearer {}", self.ha_token))
114 .header("Content-Type", "application/json")
115 .json(&data)
116 .send()
117 .await?;
118 Ok(resp.json().await?)
119 }
120
121 async fn process_request(&mut self, text: &str) -> String {
122 let lower = text.to_lowercase();
124
125 if lower.contains("states") || lower.contains("all entities") {
126 match self.get_states().await {
127 Ok(v) => format!(
128 "HA states: {}",
129 serde_json::to_string_pretty(&v).unwrap_or_else(|_| v.to_string())
130 ),
131 Err(e) => format!("HA error: {e}"),
132 }
133 } else if let Some(entity) = extract_entity_id(text) {
134 match self.get_state(&entity).await {
135 Ok(v) => format!("{entity}: {}", v["state"].as_str().unwrap_or("unknown")),
136 Err(e) => format!("HA error: {e}"),
137 }
138 } else if let Some(llm) = &mut self.llm {
139 let prompt = format!(
140 "You are a Home Assistant expert. The user said: \"{text}\"\n\
141 Interpret this as a HA request and respond helpfully. \
142 If you need to call a service, suggest: call_service(domain, service, {{data}})."
143 );
144 llm.complete(&prompt)
145 .await
146 .unwrap_or_else(|e| format!("LLM error: {e}"))
147 } else {
148 "I can query HA entity states. Try: 'get state light.living_room' or 'list all states'"
149 .into()
150 }
151 }
152
153 fn now_ms() -> u64 {
154 std::time::SystemTime::now()
155 .duration_since(std::time::UNIX_EPOCH)
156 .unwrap_or_default()
157 .as_millis() as u64
158 }
159}
160
/// Returns the first whitespace-separated token of `text` that looks like a
/// Home Assistant entity id (`<domain>.<object_id>`, e.g. `light.living_room`).
///
/// Surrounding punctuation (quotes, commas, parentheses, a sentence-ending
/// `.` or `?`) is stripped before checking, so `"is light.kitchen on?"` yields
/// `light.kitchen`. Ordinary prose such as `"Turn on the lights."` and URLs
/// yield `None`.
fn extract_entity_id(text: &str) -> Option<String> {
    text.split_whitespace()
        .map(|word| word.trim_matches(|c: char| !(c.is_ascii_alphanumeric() || c == '_')))
        .find(|candidate| is_entity_id(candidate))
        .map(str::to_owned)
}

/// Checks the `domain.object_id` shape: exactly one dot, both parts non-empty
/// and made of lowercase ASCII letters, digits, or `_`, domain starting with a
/// letter. This rejects path-like tokens (`../x`) before they reach a URL.
fn is_entity_id(candidate: &str) -> bool {
    let is_slug = |part: &str| {
        !part.is_empty()
            && part
                .bytes()
                .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'_')
    };
    match candidate.split_once('.') {
        Some((domain, object_id)) => {
            domain.starts_with(|c: char| c.is_ascii_lowercase())
                && is_slug(domain)
                && is_slug(object_id)
        }
        None => false,
    }
}
169
170#[async_trait]
171impl Actor for HomeAssistantAgent {
172 fn id(&self) -> String {
173 self.config.id.clone()
174 }
175 fn name(&self) -> &str {
176 &self.config.name
177 }
178 fn state(&self) -> ActorState {
179 self.state.clone()
180 }
181 fn metrics(&self) -> Arc<ActorMetrics> {
182 Arc::clone(&self.metrics)
183 }
184 fn mailbox(&self) -> mpsc::Sender<Message> {
185 self.mailbox_tx.clone()
186 }
187
188 async fn on_start(&mut self) -> Result<()> {
189 self.state = ActorState::Running;
190 let connected = !self.ha_url.is_empty() && !self.ha_token.is_empty();
191 tracing::info!(
192 "[{}] HA agent started (connected={})",
193 self.config.name,
194 connected
195 );
196 if let Some(pub_) = &self.publisher {
197 pub_.publish(
198 wactorz_mqtt::topics::spawn(&self.config.id),
199 &serde_json::json!({
200 "agentId": self.config.id,
201 "agentName": self.config.name,
202 "agentType": "home_assistant",
203 "haConnected": connected,
204 "timestampMs": Self::now_ms(),
205 }),
206 );
207 }
208 Ok(())
209 }
210
211 async fn handle_message(&mut self, message: Message) -> Result<()> {
212 use wactorz_core::message::MessageType;
213 let text = match &message.payload {
214 MessageType::Text { content } => content.clone(),
215 MessageType::Task { description, .. } => description.clone(),
216 _ => return Ok(()),
217 };
218 let response = self.process_request(&text).await;
219 if let Some(pub_) = &self.publisher {
220 pub_.publish(
221 wactorz_mqtt::topics::chat(&self.config.id),
222 &serde_json::json!({
223 "from": self.config.name,
224 "to": message.from.as_deref().unwrap_or("user"),
225 "content": response,
226 "timestampMs": Self::now_ms(),
227 }),
228 );
229 }
230 Ok(())
231 }
232
233 async fn on_heartbeat(&mut self) -> Result<()> {
234 if let Some(pub_) = &self.publisher {
235 pub_.publish(
236 wactorz_mqtt::topics::heartbeat(&self.config.id),
237 &serde_json::json!({
238 "agentId": self.config.id,
239 "agentName": self.config.name,
240 "state": self.state,
241 "haUrl": self.ha_url,
242 "timestampMs": Self::now_ms(),
243 }),
244 );
245 }
246 Ok(())
247 }
248
249 async fn run(&mut self) -> Result<()> {
250 self.on_start().await?;
251 let mut rx = self
252 .mailbox_rx
253 .take()
254 .ok_or_else(|| anyhow::anyhow!("HomeAssistantAgent already running"))?;
255 let mut hb = tokio::time::interval(std::time::Duration::from_secs(
256 self.config.heartbeat_interval_secs,
257 ));
258 hb.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
259 loop {
260 tokio::select! {
261 biased;
262 msg = rx.recv() => {
263 match msg {
264 None => break,
265 Some(m) => {
266 self.metrics.record_received();
267 if let wactorz_core::message::MessageType::Command {
268 command: wactorz_core::message::ActorCommand::Stop
269 } = &m.payload { break; }
270 match self.handle_message(m).await {
271 Ok(_) => self.metrics.record_processed(),
272 Err(e) => { tracing::error!("[{}] {e}", self.config.name); self.metrics.record_failed(); }
273 }
274 }
275 }
276 }
277 _ = hb.tick() => {
278 self.metrics.record_heartbeat();
279 if let Err(e) = self.on_heartbeat().await {
280 tracing::error!("[{}] heartbeat: {e}", self.config.name);
281 }
282 }
283 }
284 }
285 self.state = ActorState::Stopped;
286 self.on_stop().await
287 }
288}