Coverage for connection.py: 99%

186 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-09-02 12:06 +0000

1# -*- coding: utf-8 -*- 

2# ------------------------------------------------------------------------------ 

3# 

4# Copyright 2022 Valory AG 

5# Copyright 2018-2021 Fetch.AI Limited 

6# 

7# Licensed under the Apache License, Version 2.0 (the "License"); 

8# you may not use this file except in compliance with the License. 

9# You may obtain a copy of the License at 

10# 

11# http://www.apache.org/licenses/LICENSE-2.0 

12# 

13# Unless required by applicable law or agreed to in writing, software 

14# distributed under the License is distributed on an "AS IS" BASIS, 

15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

16# See the License for the specific language governing permissions and 

17# limitations under the License. 

18# 

19# ------------------------------------------------------------------------------ 

20"""Extension to the Local Node.""" 

21import asyncio 

22import logging 

23import threading 

24from asyncio import AbstractEventLoop, Queue 

25from collections import defaultdict 

26from concurrent.futures import Future 

27from threading import Thread 

28from typing import Any, Dict, List, Optional, Tuple, cast 

29 

30from aea.common import Address 

31from aea.configurations.base import PublicId 

32from aea.connections.base import Connection, ConnectionStates 

33from aea.helpers.search.models import Description 

34from aea.mail.base import Envelope 

35from aea.protocols.base import Message 

36from aea.protocols.dialogue.base import Dialogue as BaseDialogue 

37 

38from packages.fetchai.protocols.default.message import DefaultMessage 

39from packages.fetchai.protocols.oef_search.dialogues import ( 

40 OefSearchDialogue as BaseOefSearchDialogue, 

41) 

42from packages.fetchai.protocols.oef_search.dialogues import ( 

43 OefSearchDialogues as BaseOefSearchDialogues, 

44) 

45from packages.fetchai.protocols.oef_search.message import OefSearchMessage 

46 

47 

# Logger used by LocalNode when the caller does not pass one explicitly.
_default_logger = logging.getLogger("aea.packages.fetchai.connections.local")

# Header fields for messages created by the node itself; TARGET and MESSAGE_ID
# are the values given to the error message built for unknown destinations.
TARGET = 0
MESSAGE_ID = 1
RESPONSE_TARGET = MESSAGE_ID
RESPONSE_MESSAGE_ID = MESSAGE_ID + 1
STUB_DIALOGUE_ID = 0

# Public identifier of this connection package (used as `connection_id`).
PUBLIC_ID = PublicId.from_str("fetchai/local:0.20.0")

# Alias so that the rest of the module refers to a local dialogue name.
OefSearchDialogue = BaseOefSearchDialogue

# Addresses the node uses for itself: the first as the `self_address` of its
# search dialogues, the second as the sender of node-generated error envelopes.
OEF_LOCAL_NODE_SEARCH_ADDRESS = "oef_local_node_search"
OEF_LOCAL_NODE_ADDRESS = "oef_local_node"

62 

63 

class OefSearchDialogues(BaseOefSearchDialogues):
    """Registry of the OEF search dialogues held on behalf of the local node."""

    def __init__(self) -> None:
        """Create the registry, bound to the local node's search address."""

        def _node_role(  # pylint: disable=unused-argument
            message: Message, receiver_address: Address
        ) -> BaseDialogue.Role:
            """Return the role played in a dialogue opened by a first message.

            :param message: an incoming/outgoing first message
            :param receiver_address: the address of the receiving agent
            :return: always the OEF node role
            """
            # Both arguments are ignored: this side of the dialogue is always
            # the node, whoever opened it.
            return OefSearchDialogue.Role.OEF_NODE

        super().__init__(
            self_address=OEF_LOCAL_NODE_SEARCH_ADDRESS,
            role_from_first_message=_node_role,
            dialogue_class=OefSearchDialogue,
        )

88 

89 

class LocalNode:
    """A light-weight local implementation of a OEF Node.

    The node owns an asyncio event loop that runs in a dedicated daemon thread.
    Connected agents push envelopes into a single inbound queue; the receiving
    loop either answers OEF search requests itself (service registration,
    unregistration and search) or forwards the envelope to the queue of the
    destination agent.
    """

    def __init__(
        self,
        loop: Optional[AbstractEventLoop] = None,
        logger: logging.Logger = _default_logger,
    ):
        """
        Initialize a local (i.e. non-networked) implementation of an OEF Node.

        :param loop: the event loop. If None, a new event loop is instantiated.
        :param logger: the logger.
        """
        # Guards `services` and `_out_queues` against concurrent modification.
        self._lock = threading.Lock()
        # Maps an agent address to the service descriptions it registered.
        self.services = defaultdict(list)  # type: Dict[str, List[Description]]
        self._loop = loop if loop is not None else asyncio.new_event_loop()
        self._thread = Thread(target=self._run_loop, daemon=True)

        # Created by `receiving_loop`, i.e. only available after `start`.
        self._in_queue = None  # type: Optional[asyncio.Queue]
        # Maps an agent address to the queue that agent reads from.
        self._out_queues = {}  # type: Dict[str, asyncio.Queue]

        self._receiving_loop_task = None  # type: Optional[Future]
        self.address: Optional[Address] = None
        self._dialogues: Optional[OefSearchDialogues] = None
        self.logger = logger
        # Set by `receiving_loop` once the inbound queue exists.
        self.started_event = threading.Event()

    def __enter__(self) -> "LocalNode":
        """Start the local node."""
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Stop the local node."""
        self.stop()

    def _run_loop(self) -> None:
        """
        Run the asyncio loop until it is stopped.

        This method is the target of the node's background thread.
        """
        self.logger.debug("Starting threaded asyncio loop...")
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()
        self.logger.debug("Asyncio loop has been stopped.")

    async def connect(
        self, address: Address, writer: asyncio.Queue
    ) -> Optional[asyncio.Queue]:
        """
        Connect an address to the node.

        :param address: the address of the agent.
        :param writer: the queue where the client is listening.
        :return: an asynchronous queue, that constitutes the communication channel,
            or None if the address is already connected.
        :raises ValueError: if the node has not been started yet.
        """
        if address in self._out_queues:
            return None

        if self._in_queue is None:  # pragma: nocover
            raise ValueError("In queue not set.")
        q = self._in_queue  # type: asyncio.Queue
        self._out_queues[address] = writer

        self.address = address
        # A fresh dialogue registry is created on every successful connect.
        self._dialogues = OefSearchDialogues()
        return q

    def start(self) -> None:
        """Start the node and block until the receiving loop is ready."""
        if not self._loop.is_running() and not self._thread.is_alive():
            self._thread.start()
        self._receiving_loop_task = asyncio.run_coroutine_threadsafe(
            self.receiving_loop(), loop=self._loop
        )
        self.started_event.wait()
        self.logger.debug("Local node has been started.")

    def stop(self) -> None:
        """
        Stop the node.

        :raises ValueError: if the node was never started.
        """
        if self._receiving_loop_task is None or self._in_queue is None:
            raise ValueError("Connection not started!")
        # `None` is the sentinel that makes the receiving loop return.
        asyncio.run_coroutine_threadsafe(self._in_queue.put(None), self._loop).result()
        self._receiving_loop_task.result()

        if self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread.is_alive():
            self._thread.join()

    async def receiving_loop(self) -> None:
        """Process incoming messages until the `None` sentinel is received."""
        self._in_queue = asyncio.Queue()
        self.started_event.set()
        while True:
            envelope = await self._in_queue.get()
            if envelope is None:
                self.logger.debug("Receiving loop terminated.")
                return
            self.logger.debug("Handling envelope: %s", envelope)
            await self._handle_envelope(envelope)

    async def _handle_envelope(self, envelope: Envelope) -> None:
        """Dispatch an envelope to the OEF handler or to the forwarding handler.

        :param envelope: the envelope
        """
        if (
            envelope.protocol_specification_id
            == OefSearchMessage.protocol_specification_id
        ):
            await self._handle_oef_message(envelope)
        else:
            OEFLocalConnection._ensure_valid_envelope_for_external_comms(  # pylint: disable=protected-access
                envelope
            )
            await self._handle_agent_message(envelope)

    async def _handle_oef_message(self, envelope: Envelope) -> None:
        """Handle oef messages.

        :param envelope: the envelope
        :raises ValueError: if the envelope does not carry an OefSearchMessage.
        """
        if not isinstance(envelope.message, OefSearchMessage):  # pragma: nocover
            raise ValueError("Message not of type OefSearchMessage.")
        oef_message, dialogue = self._get_message_and_dialogue(envelope)

        if dialogue is None:
            self.logger.warning("Could not create dialogue for message=%s", oef_message)
            return

        performative = oef_message.performative
        if performative == OefSearchMessage.Performative.REGISTER_SERVICE:
            await self._register_service(
                envelope.sender, oef_message.service_description
            )
        elif performative == OefSearchMessage.Performative.UNREGISTER_SERVICE:
            await self._unregister_service(oef_message, dialogue)
        elif performative == OefSearchMessage.Performative.SEARCH_SERVICES:
            await self._search_services(oef_message, dialogue)
        else:
            # request not recognized: nothing to do, but leave a trace
            self.logger.debug("Ignoring unsupported performative: %s", performative)

    async def _handle_agent_message(self, envelope: Envelope) -> None:
        """
        Forward an envelope to the right agent.

        If the destination is not connected, an error message is sent back to
        the sender instead.

        :param envelope: the envelope
        """
        if envelope.to in self._out_queues:
            await self._send(envelope)
            return

        msg = DefaultMessage(
            performative=DefaultMessage.Performative.ERROR,
            dialogue_reference=("", ""),
            target=TARGET,
            message_id=MESSAGE_ID,
            error_code=DefaultMessage.ErrorCode.INVALID_DIALOGUE,
            error_msg="Destination not available",
            error_data={},
        )
        error_envelope = Envelope(
            to=envelope.sender,
            sender=OEF_LOCAL_NODE_ADDRESS,
            message=msg,
        )
        await self._send(error_envelope)

    async def _register_service(
        self, address: Address, service_description: Description
    ) -> None:
        """
        Register a service agent in the service directory of the node.

        :param address: the address of the service agent to be registered.
        :param service_description: the description of the service agent to be registered.
        """
        with self._lock:
            self.services[address].append(service_description)

    async def _unregister_service(
        self,
        oef_search_msg: OefSearchMessage,
        dialogue: OefSearchDialogue,
    ) -> None:
        """
        Unregister a service agent.

        An OEF error is sent back when there is nothing to remove, i.e. the
        sender has no registered service or never registered this description.

        :param oef_search_msg: the incoming message.
        :param dialogue: the dialogue.
        """
        service_description = oef_search_msg.service_description
        address = oef_search_msg.sender
        with self._lock:
            # `.get` does not trigger the defaultdict factory, so looking up an
            # unknown address does not create an empty entry.
            registered = self.services.get(address)
            if registered is not None and service_description in registered:
                registered.remove(service_description)
                if not registered:
                    self.services.pop(address)
                return

        # The reply is built and sent outside the lock so that the thread lock
        # is never held across an `await`.
        msg = dialogue.reply(
            performative=OefSearchMessage.Performative.OEF_ERROR,
            target_message=oef_search_msg,
            oef_error_operation=OefSearchMessage.OefErrorOperation.UNREGISTER_SERVICE,
        )
        envelope = Envelope(
            to=msg.to,
            sender=msg.sender,
            message=msg,
        )
        await self._send(envelope)

    async def _search_services(
        self,
        oef_search_msg: OefSearchMessage,
        dialogue: OefSearchDialogue,
    ) -> None:
        """
        Search the agents in the local Service Directory, and send back the result.

        This is actually a dummy search, it will return all the registered agents with the specified data model.
        If the data model is not specified, it will return all the agents.

        :param oef_search_msg: the message.
        :param dialogue: the dialogue.
        """
        query = oef_search_msg.query
        with self._lock:
            if query.model is None:
                matches = set(self.services)
            else:
                matches = {
                    agent_address
                    for agent_address, descriptions in self.services.items()
                    if any(
                        description.data_model == query.model
                        for description in descriptions
                    )
                }

        # Sorted so that the reply is deterministic; each agent appears once.
        msg = dialogue.reply(
            performative=OefSearchMessage.Performative.SEARCH_RESULT,
            target_message=oef_search_msg,
            agents=tuple(sorted(matches)),
        )

        envelope = Envelope(
            to=msg.to,
            sender=msg.sender,
            message=msg,
        )
        await self._send(envelope)

    def _get_message_and_dialogue(
        self, envelope: Envelope
    ) -> Tuple[OefSearchMessage, Optional[OefSearchDialogue]]:
        """
        Get the message of an envelope and the dialogue it belongs to.

        :param envelope: incoming envelope

        :return: the message and the dialogue returned by the dialogue registry
            (possibly None)
        :raises ValueError: if no agent has connected yet.
        """
        if self._dialogues is None:  # pragma: nocover
            raise ValueError("Call connect before!")
        message = cast(OefSearchMessage, envelope.message)
        dialogue = cast(Optional[OefSearchDialogue], self._dialogues.update(message))
        return message, dialogue

    async def _send(self, envelope: Envelope) -> None:
        """
        Deliver an envelope to the queue of its destination.

        If the destination is not (or no longer) connected the envelope is
        dropped with a warning, so that a late reply to a disconnected agent
        cannot terminate the receiving loop.

        :param envelope: the envelope to deliver.
        """
        destination = envelope.to
        destination_queue = self._out_queues.get(destination)
        if destination_queue is None:
            self.logger.warning(
                "Destination '%s' is not connected, dropping envelope %s",
                destination,
                envelope,
            )
            return
        # NOTE(review): relies on the private `Queue._loop` attribute to schedule
        # the put on the loop that consumes the queue; presumably that attribute
        # must already be bound - confirm on the targeted Python versions.
        destination_queue._loop.call_soon_threadsafe(destination_queue.put_nowait, envelope)  # type: ignore  # pylint: disable=protected-access
        self.logger.debug("Send envelope %s", envelope)

    async def disconnect(self, address: Address) -> None:
        """
        Disconnect an agent and forget the services it registered.

        :param address: the address of the agent
        """
        with self._lock:
            self._out_queues.pop(address, None)
            self.services.pop(address, None)

378 

379 

class OEFLocalConnection(Connection):
    """
    Proxy to the functionality of a local OEF node.

    Envelopes sent through this connection are put on the inbound queue of a
    LocalNode, which forwards them to other agents or answers OEF search
    requests itself. It is useful for local testing.
    """

    connection_id = PUBLIC_ID

    def __init__(self, local_node: Optional[LocalNode] = None, **kwargs: Any) -> None:
        """
        Load the connection configuration.

        Initialize a OEF proxy for a local OEF Node

        :param local_node: the Local OEF Node object. This reference must be the same across the agents of interest. (Note, AEA loader will not accept this argument.)
        :param kwargs: keyword arguments.
        """
        super().__init__(**kwargs)
        self._local_node = local_node
        # Queue this connection receives from (filled by the node).
        self._reader = None  # type: Optional[Queue]
        # Queue this connection sends to (the node's inbound queue).
        self._writer = None  # type: Optional[Queue]

    async def connect(self) -> None:
        """
        Connect to the local OEF Node.

        :raises ValueError: if no local node was provided.
        """
        if self._local_node is None:  # pragma: nocover
            raise ValueError("No local node set!")

        if self.is_connected:  # pragma: nocover
            return

        with self._connect_context():
            self._reader = Queue()
            # The node returns None when this address is already connected.
            self._writer = await self._local_node.connect(self.address, self._reader)

    async def disconnect(self) -> None:
        """
        Disconnect from the local OEF Node.

        :raises ValueError: if no local node or no reader queue is set.
        """
        if self._local_node is None:
            raise ValueError("No local node set!")  # pragma: nocover
        if self.is_disconnected:
            return  # pragma: nocover
        self.state = ConnectionStates.disconnecting
        if self._reader is None:
            raise ValueError("No reader set!")  # pragma: nocover
        await self._local_node.disconnect(self.address)
        # `None` is the sentinel that unblocks a pending `receive`.
        await self._reader.put(None)
        self._reader, self._writer = None, None
        self.state = ConnectionStates.disconnected

    async def send(self, envelope: Envelope) -> None:
        """
        Send a message.

        :param envelope: the envelope.
        :raises ValueError: if there is no writer queue, e.g. because the node
            returned no channel on connect.
        """
        self._ensure_connected()
        if self._writer is None:  # pragma: nocover
            raise ValueError("No writer set!")
        # NOTE(review): relies on the private `Queue._loop` attribute to schedule
        # the put on the node's loop - confirm on the targeted Python versions.
        self._writer._loop.call_soon_threadsafe(self._writer.put_nowait, envelope)  # type: ignore  # pylint: disable=protected-access

    async def receive(self, *args: Any, **kwargs: Any) -> Optional["Envelope"]:
        """
        Receive an envelope. Blocking.

        Any error raised while waiting is logged and reported as None.

        :param args: positional arguments.
        :param kwargs: keyword arguments.
        :return: the envelope received, or None.
        """
        self._ensure_connected()
        try:
            if self._reader is None:
                raise ValueError("No reader set!")  # pragma: nocover
            envelope = await self._reader.get()
            if envelope is None:  # pragma: no cover
                self.logger.debug("Receiving task terminated.")
                return None
            self.logger.debug("Received envelope %s", envelope)
            return envelope
        except Exception:  # pragma: nocover # pylint: disable=broad-except
            # Keep the best-effort contract (return None) but do not hide the
            # failure: record the traceback for whoever debugs it.
            self.logger.exception("Error while receiving an envelope.")
            return None