Coverage for connection.py: 99%
186 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-09-02 12:06 +0000
« prev ^ index » next coverage.py v6.4.4, created at 2022-09-02 12:06 +0000
1# -*- coding: utf-8 -*-
2# ------------------------------------------------------------------------------
3#
4# Copyright 2022 Valory AG
5# Copyright 2018-2021 Fetch.AI Limited
6#
7# Licensed under the Apache License, Version 2.0 (the "License");
8# you may not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS,
15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18#
19# ------------------------------------------------------------------------------
20"""Extension to the Local Node."""
21import asyncio
22import logging
23import threading
24from asyncio import AbstractEventLoop, Queue
25from collections import defaultdict
26from concurrent.futures import Future
27from threading import Thread
28from typing import Any, Dict, List, Optional, Tuple, cast
30from aea.common import Address
31from aea.configurations.base import PublicId
32from aea.connections.base import Connection, ConnectionStates
33from aea.helpers.search.models import Description
34from aea.mail.base import Envelope
35from aea.protocols.base import Message
36from aea.protocols.dialogue.base import Dialogue as BaseDialogue
38from packages.fetchai.protocols.default.message import DefaultMessage
39from packages.fetchai.protocols.oef_search.dialogues import (
40 OefSearchDialogue as BaseOefSearchDialogue,
41)
42from packages.fetchai.protocols.oef_search.dialogues import (
43 OefSearchDialogues as BaseOefSearchDialogues,
44)
45from packages.fetchai.protocols.oef_search.message import OefSearchMessage
# Logger used by LocalNode when the caller does not pass one explicitly.
_default_logger = logging.getLogger("aea.packages.fetchai.connections.local")

# Message numbering constants. TARGET and MESSAGE_ID are the values placed on
# the error message the node builds when a destination is not connected.
TARGET = 0
MESSAGE_ID = 1
RESPONSE_TARGET = MESSAGE_ID
RESPONSE_MESSAGE_ID = MESSAGE_ID + 1
STUB_DIALOGUE_ID = 0

# Public identifier of this connection package.
PUBLIC_ID = PublicId.from_str("fetchai/local:0.20.0")

# The node uses the protocol's dialogue class unchanged.
OefSearchDialogue = BaseOefSearchDialogue

# Addresses the node uses for itself: the first as the dialogues' own address,
# the second as the sender of the error envelopes it generates.
OEF_LOCAL_NODE_SEARCH_ADDRESS = "oef_local_node_search"
OEF_LOCAL_NODE_ADDRESS = "oef_local_node"
class OefSearchDialogues(BaseOefSearchDialogues):
    """Registry of the OEF search dialogues held on behalf of the local node."""

    def __init__(self) -> None:
        """Set up the dialogues with the local node's search address."""

        def _node_role(  # pylint: disable=unused-argument
            message: Message, receiver_address: Address
        ) -> BaseDialogue.Role:
            """
            Determine the role played in a dialogue from its first message.

            :param message: an incoming/outgoing first message
            :param receiver_address: the address of the receiving agent
            :return: The role of the agent
            """
            # Both arguments are ignored: the node role is returned for every
            # dialogue, since this class tracks dialogues for the node itself.
            return OefSearchDialogue.Role.OEF_NODE

        super().__init__(
            self_address=OEF_LOCAL_NODE_SEARCH_ADDRESS,
            role_from_first_message=_node_role,
            dialogue_class=OefSearchDialogue,
        )
class LocalNode:
    """
    A light-weight local implementation of a OEF Node.

    The node owns an asyncio event loop that it runs in a daemon thread.
    Connected agents put envelopes on a shared input queue; the receiving
    loop either serves them (OEF search protocol) or forwards them to the
    output queue registered for the destination address.
    """

    def __init__(
        self,
        loop: Optional[AbstractEventLoop] = None,
        logger: logging.Logger = _default_logger,
    ):
        """
        Initialize a local (i.e. non-networked) implementation of an OEF Node.

        :param loop: the event loop. If None, a new event loop is instantiated.
        :param logger: the logger.
        """
        self._lock = threading.Lock()
        self.services = defaultdict(list)  # type: Dict[str, List[Description]]
        self._loop = loop if loop is not None else asyncio.new_event_loop()
        self._thread = Thread(target=self._run_loop, daemon=True)

        self._in_queue = None  # type: Optional[asyncio.Queue]
        self._out_queues = {}  # type: Dict[str, asyncio.Queue]
        # Event loop that was running when each address connected; used as a
        # fallback when the queue itself does not expose its loop.
        self._out_loops = {}  # type: Dict[str, AbstractEventLoop]

        self._receiving_loop_task = None  # type: Optional[Future]
        self.address: Optional[Address] = None
        self._dialogues: Optional[OefSearchDialogues] = None
        self.logger = logger
        self.started_event = threading.Event()

    def __enter__(self) -> "LocalNode":
        """Start the local node."""
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Stop the local node."""
        self.stop()

    def _run_loop(self) -> None:
        """
        Run the asyncio loop.

        This is the target of the node's own background thread; it blocks
        until the loop is stopped.
        """
        self.logger.debug("Starting threaded asyncio loop...")
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()
        self.logger.debug("Asyncio loop has been stopped.")

    async def connect(
        self, address: Address, writer: asyncio.Queue
    ) -> Optional[asyncio.Queue]:
        """
        Connect an address to the node.

        :param address: the address of the agent.
        :param writer: the queue where the client is listening.
        :return: an asynchronous queue, that constitutes the communication channel,
            or None if the address is already connected.
        :raises ValueError: if the node has not been started yet.
        """
        if address in self._out_queues:
            return None

        if self._in_queue is None:  # pragma: nocover
            raise ValueError("In queue not set.")
        self._out_queues[address] = writer
        # NOTE(review): assumes the caller awaits connect() from the loop that
        # consumes `writer` (as OEFLocalConnection.connect does) - confirm for
        # other callers.
        self._out_loops[address] = asyncio.get_event_loop()

        self.address = address
        # NOTE(review): the dialogues are re-created on every connect, which
        # discards dialogue state of previously connected agents - confirm
        # this is intended.
        self._dialogues = OefSearchDialogues()
        return self._in_queue

    def start(self) -> None:
        """Start the node and block until the receiving loop is ready."""
        if not self._loop.is_running() and not self._thread.is_alive():
            self._thread.start()
        self._receiving_loop_task = asyncio.run_coroutine_threadsafe(
            self.receiving_loop(), loop=self._loop
        )
        # Set by receiving_loop() once the input queue exists.
        self.started_event.wait()
        self.logger.debug("Local node has been started.")

    def stop(self) -> None:
        """
        Stop the node.

        :raises ValueError: if the node was never started.
        """
        if self._receiving_loop_task is None or self._in_queue is None:
            raise ValueError("Connection not started!")
        # A None item is the sentinel that terminates the receiving loop.
        asyncio.run_coroutine_threadsafe(self._in_queue.put(None), self._loop).result()
        self._receiving_loop_task.result()

        if self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread.is_alive():
            self._thread.join()

    async def receiving_loop(self) -> None:
        """Process incoming messages until a None sentinel is received."""
        self._in_queue = asyncio.Queue()
        self.started_event.set()
        while True:
            envelope = await self._in_queue.get()
            if envelope is None:
                self.logger.debug("Receiving loop terminated.")
                return
            self.logger.debug("Handling envelope: %s", envelope)
            await self._handle_envelope(envelope)

    async def _handle_envelope(self, envelope: Envelope) -> None:
        """
        Handle an envelope.

        OEF search envelopes are served by the node itself; any other envelope
        is validated and forwarded to its destination.

        :param envelope: the envelope
        """
        if (
            envelope.protocol_specification_id
            == OefSearchMessage.protocol_specification_id
        ):
            await self._handle_oef_message(envelope)
        else:
            OEFLocalConnection._ensure_valid_envelope_for_external_comms(  # pylint: disable=protected-access
                envelope
            )
            await self._handle_agent_message(envelope)

    async def _handle_oef_message(self, envelope: Envelope) -> None:
        """
        Handle oef messages.

        :param envelope: the envelope
        :raises ValueError: if the envelope does not carry an OefSearchMessage.
        """
        if not isinstance(envelope.message, OefSearchMessage):  # pragma: nocover
            raise ValueError("Message not of type OefSearchMessage.")
        oef_message, dialogue = self._get_message_and_dialogue(envelope)

        if dialogue is None:
            self.logger.warning("Could not create dialogue for message=%s", oef_message)
            return

        performative = oef_message.performative
        if performative == OefSearchMessage.Performative.REGISTER_SERVICE:
            await self._register_service(
                envelope.sender, oef_message.service_description
            )
        elif performative == OefSearchMessage.Performative.UNREGISTER_SERVICE:
            await self._unregister_service(oef_message, dialogue)
        elif performative == OefSearchMessage.Performative.SEARCH_SERVICES:
            await self._search_services(oef_message, dialogue)
        else:
            # Request not recognized: leave a trace instead of dropping it silently.
            self.logger.warning(
                "Unsupported OEF search performative: %s", performative
            )

    async def _handle_agent_message(self, envelope: Envelope) -> None:
        """
        Forward an envelope to the right agent.

        If the destination is not connected, an error message is sent back to
        the sender instead.

        :param envelope: the envelope
        """
        if envelope.to in self._out_queues:
            await self._send(envelope)
            return

        msg = DefaultMessage(
            performative=DefaultMessage.Performative.ERROR,
            dialogue_reference=("", ""),
            target=TARGET,
            message_id=MESSAGE_ID,
            error_code=DefaultMessage.ErrorCode.INVALID_DIALOGUE,
            error_msg="Destination not available",
            error_data={},
        )
        error_envelope = Envelope(
            to=envelope.sender,
            sender=OEF_LOCAL_NODE_ADDRESS,
            message=msg,
        )
        await self._send(error_envelope)

    async def _register_service(
        self, address: Address, service_description: Description
    ) -> None:
        """
        Register a service agent in the service directory of the node.

        :param address: the address of the service agent to be registered.
        :param service_description: the description of the service agent to be registered.
        """
        with self._lock:
            self.services[address].append(service_description)

    async def _unregister_service(
        self,
        oef_search_msg: OefSearchMessage,
        dialogue: OefSearchDialogue,
    ) -> None:
        """
        Unregister a service agent.

        An OEF_ERROR reply is sent when the sender has no registration matching
        the given service description.

        :param oef_search_msg: the incoming message.
        :param dialogue: the dialogue.
        """
        service_description = oef_search_msg.service_description
        address = oef_search_msg.sender
        with self._lock:
            # .get() does not create an entry in the defaultdict.
            registered = self.services.get(address)
            removed = registered is not None and service_description in registered
            if removed:
                registered.remove(service_description)
                if not registered:
                    self.services.pop(address)
        if removed:
            return

        # Build and send the error outside the lock.
        msg = dialogue.reply(
            performative=OefSearchMessage.Performative.OEF_ERROR,
            target_message=oef_search_msg,
            oef_error_operation=OefSearchMessage.OefErrorOperation.UNREGISTER_SERVICE,
        )
        envelope = Envelope(
            to=msg.to,
            sender=msg.sender,
            message=msg,
        )
        await self._send(envelope)

    async def _search_services(
        self,
        oef_search_msg: OefSearchMessage,
        dialogue: OefSearchDialogue,
    ) -> None:
        """
        Search the agents in the local Service Directory, and send back the result.

        This is actually a dummy search, it will return all the registered agents with the specified data model.
        If the data model is not specified, it will return all the agents.

        :param oef_search_msg: the message.
        :param dialogue: the dialogue.
        """
        query = oef_search_msg.query
        with self._lock:
            if query.model is None:
                matching = set(self.services)
            else:
                matching = {
                    agent_address
                    for agent_address, descriptions in self.services.items()
                    if any(
                        description.data_model == query.model
                        for description in descriptions
                    )
                }

        msg = dialogue.reply(
            performative=OefSearchMessage.Performative.SEARCH_RESULT,
            target_message=oef_search_msg,
            agents=tuple(sorted(matching)),
        )
        envelope = Envelope(
            to=msg.to,
            sender=msg.sender,
            message=msg,
        )
        await self._send(envelope)

    def _get_message_and_dialogue(
        self, envelope: Envelope
    ) -> Tuple[OefSearchMessage, Optional[OefSearchDialogue]]:
        """
        Get the message carried by the envelope and the dialogue it belongs to.

        :param envelope: incoming envelope

        :return: Tuple[Message, Optional[Dialogue]]
        :raises ValueError: if no agent has connected yet.
        """
        if self._dialogues is None:  # pragma: nocover
            raise ValueError("Call connect before!")
        message = cast(OefSearchMessage, envelope.message)
        dialogue = cast(Optional[OefSearchDialogue], self._dialogues.update(message))
        return message, dialogue

    async def _send(self, envelope: Envelope) -> None:
        """
        Put an envelope on the output queue of its destination.

        The envelope is dropped, with a warning, when the destination is not
        connected (e.g. it disconnected while the request was being served).

        :param envelope: the envelope to deliver.
        """
        destination = envelope.to
        destination_queue = self._out_queues.get(destination)
        if destination_queue is None:
            self.logger.warning(
                "Destination %s is not connected, dropping envelope %s",
                destination,
                envelope,
            )
            return

        # `Queue._loop` is private and may be unset (None) on recent Python
        # versions until the queue has been waited on; fall back to the loop
        # recorded when the address connected.
        loop = getattr(destination_queue, "_loop", None)
        if loop is None:
            loop = self._out_loops.get(destination)
        if loop is None:
            self.logger.warning(
                "No event loop known for destination %s, dropping envelope %s",
                destination,
                envelope,
            )
            return

        # The queue is consumed from another thread's loop, hence threadsafe.
        loop.call_soon_threadsafe(destination_queue.put_nowait, envelope)
        self.logger.debug("Send envelope %s", envelope)

    async def disconnect(self, address: Address) -> None:
        """
        Disconnect an address and drop its registered services.

        :param address: the address of the agent
        """
        with self._lock:
            self._out_queues.pop(address, None)
            self._out_loops.pop(address, None)
            self.services.pop(address, None)
380class OEFLocalConnection(Connection):
381 """
382 Proxy to the functionality of the OEF.
384 It allows the interaction between agents, but not the search functionality.
385 It is useful for local testing.
386 """
388 connection_id = PUBLIC_ID
390 def __init__(self, local_node: Optional[LocalNode] = None, **kwargs: Any) -> None:
391 """
392 Load the connection configuration.
394 Initialize a OEF proxy for a local OEF Node
396 :param local_node: the Local OEF Node object. This reference must be the same across the agents of interest. (Note, AEA loader will not accept this argument.)
397 :param kwargs: keyword arguments.
398 """
399 super().__init__(**kwargs)
400 self._local_node = local_node
401 self._reader = None # type: Optional[Queue]
402 self._writer = None # type: Optional[Queue]
404 async def connect(self) -> None:
405 """Connect to the local OEF Node."""
406 if self._local_node is None: # pragma: nocover
407 raise ValueError("No local node set!")
409 if self.is_connected: # pragma: nocover
410 return
412 with self._connect_context():
413 self._reader = Queue()
414 self._writer = await self._local_node.connect(self.address, self._reader)
416 async def disconnect(self) -> None:
417 """Disconnect from the local OEF Node."""
418 if self._local_node is None:
419 raise ValueError("No local node set!") # pragma: nocover
420 if self.is_disconnected:
421 return # pragma: nocover
422 self.state = ConnectionStates.disconnecting
423 if self._reader is None:
424 raise ValueError("No reader set!") # pragma: nocover
425 await self._local_node.disconnect(self.address)
426 await self._reader.put(None)
427 self._reader, self._writer = None, None
428 self.state = ConnectionStates.disconnected
430 async def send(self, envelope: Envelope) -> None:
431 """
432 Send a message.
434 :param envelope: the envelope.
435 """
436 self._ensure_connected()
437 self._writer._loop.call_soon_threadsafe(self._writer.put_nowait, envelope) # type: ignore # pylint: disable=protected-access
439 async def receive(self, *args: Any, **kwargs: Any) -> Optional["Envelope"]:
440 """
441 Receive an envelope. Blocking.
443 :param args: positional arguments.
444 :param kwargs: keyword arguments.
445 :return: the envelope received, or None.
446 """
447 self._ensure_connected()
448 try:
449 if self._reader is None:
450 raise ValueError("No reader set!") # pragma: nocover
451 envelope = await self._reader.get()
452 if envelope is None: # pragma: no cover
453 self.logger.debug("Receiving task terminated.")
454 return None
455 self.logger.debug("Received envelope {}".format(envelope))
456 return envelope
457 except Exception: # pragma: nocover # pylint: disable=broad-except
458 return None