Coverage for src/extratools_cloud/aws/sqs.py: 0%
94 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-21 09:46 -0700
1import gzip
2import json
3from collections.abc import Iterable
4from os import getenv
5from typing import Any, Literal, cast, override
6from uuid import uuid4
8import boto3
9import simple_zstd as zstd
10from boto3.resources.base import ServiceResource
11from extratools_core.crudl import CRUDLDict
12from extratools_core.json import JsonDict
13from extratools_core.str import compress
14from toolz.itertoolz import partition_all
16from ..common.router import BaseRouter
17from .helpers import ClientErrorHandler
19STAGE: str = getenv("STAGE", "local")
22default_service_resource: ServiceResource = boto3.resource(
23 "sqs",
24 endpoint_url=(
25 "http://localhost:4566" if STAGE == "local"
26 else None
27 ),
28)
30type Queue = Any
32FIFO_QUEUE_NAME_SUFFIX = ".fifo"
def get_queue_json(queue: Queue) -> JsonDict:
    """
    Summarize a queue as a plain dict.

    Returns a dict with two keys: `url` (the queue's URL) and `attributes`
    (the queue's attribute mapping), both read straight off `queue`.
    """
    url = queue.url
    attributes = queue.attributes
    return dict(url=url, attributes=attributes)
def get_resource_dict(
    *,
    service_resource: ServiceResource | None = None,
    queue_name_prefix: str | None = None,
    json_only: bool = False,
) -> CRUDLDict[str, Queue | JsonDict]:
    """
    Expose SQS queues as a `CRUDLDict` keyed by queue name.

    Args:
        service_resource: boto3 SQS service resource; defaults to `default_service_resource`.
        queue_name_prefix: If set (non-empty), only queues whose names start with this prefix
            are listed, and every other operation on a name without it raises `ValueError`.
        json_only: If true, read and list return `{"url": ..., "attributes": ...}` dicts
            (see `get_queue_json`) instead of boto3 `Queue` resources.

    Returns:
        A `CRUDLDict` wired to callbacks that create, read, update (set attributes), delete,
        and list queues. Read, update, and delete are all wrapped with
        `ClientErrorHandler("QueueDoesNotExist", KeyError)`, so a missing queue presumably
        surfaces as `KeyError` consistently across operations.
    """
    service_resource = service_resource or default_service_resource

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs/service-resource/index.html

    def check_queue_name(queue_name: str) -> None:
        """Raise `ValueError` if `queue_name` lies outside the configured prefix."""
        if queue_name_prefix and not queue_name.startswith(queue_name_prefix):
            raise ValueError(
                f"Queue name {queue_name!r} does not start with prefix {queue_name_prefix!r}",
            )

    def create_func(queue_name: str | None, attributes: dict[str, str]) -> None:
        """Create queue `queue_name`; names ending in `.fifo` become FIFO queues."""
        if queue_name is None:
            raise ValueError("A queue name is required to create a queue")

        check_queue_name(queue_name)

        # SQS accepts the `FifoQueue` attribute only for FIFO queues; sending an explicit
        # "false" for a standard queue is rejected, so omit it there.
        # Caller-supplied attributes still take precedence.
        fifo_attributes: dict[str, str] = (
            {"FifoQueue": "true"} if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
            else {}
        )
        service_resource.create_queue(
            QueueName=queue_name,
            Attributes={
                **fifo_attributes,
                **attributes,
            },
        )

    @ClientErrorHandler(
        "QueueDoesNotExist",
        KeyError,
    )
    def read_func(queue_name: str) -> Queue:
        """Look up queue `queue_name`, as a resource or as JSON depending on `json_only`."""
        check_queue_name(queue_name)

        queue = service_resource.get_queue_by_name(
            QueueName=queue_name,
        )
        if not json_only:
            return queue

        return get_queue_json(queue)

    # Same missing-queue mapping as `read_func`, so update/delete of an unknown name
    # fails the same way a read does.
    @ClientErrorHandler(
        "QueueDoesNotExist",
        KeyError,
    )
    def update_func(queue_name: str, attributes: dict[str, str]) -> None:
        """Set `attributes` on existing queue `queue_name`."""
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).set_attributes(
            Attributes={
                **attributes,
            },
        )

    @ClientErrorHandler(
        "QueueDoesNotExist",
        KeyError,
    )
    def delete_func(queue_name: str) -> None:
        """Delete existing queue `queue_name`."""
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).delete()

    def list_func(_: None) -> Iterable[tuple[str, Queue]]:
        """Yield `(queue_name, queue)` for every queue, filtered by prefix when configured."""
        for queue in (
            service_resource.queues.filter(
                QueueNamePrefix=queue_name_prefix,
            )
            if queue_name_prefix
            else service_resource.queues.all()
        ):
            # The queue name is the last path segment of the queue URL.
            queue_name = cast("str", queue.url).rsplit('/', maxsplit=1)[-1]
            yield queue_name, (
                get_queue_json(queue) if json_only
                else queue
            )

    return CRUDLDict[str, Queue | JsonDict](
        create_func=create_func,
        read_func=read_func,
        update_func=update_func,
        delete_func=delete_func,
        list_func=list_func,
    )
126MESSAGE_BATCH_SIZE = 10
129def __encode_body(
130 body: str,
131 *,
132 encoding: Literal["gzip", "zstd"] | None = None,
133) -> str:
134 match encoding:
135 case "gzip":
136 return compress(body, gzip.compress)
137 case "zstd":
138 return compress(body, zstd.compress)
139 case None:
140 return body
141 case _:
142 raise ValueError
def send_messages(
    queue: Queue,
    messages: Iterable[JsonDict],
    group: str | None = None,
    *,
    encoding: Literal["gzip", "zstd"] | None = None,
) -> Iterable[JsonDict]:
    """
    Send JSON messages to an SQS queue in batches of `MESSAGE_BATCH_SIZE`.

    Arguments are validated immediately. Sending is lazy: nothing is sent until the
    returned iterable is consumed, and each batch is sent as it is reached.

    Args:
        queue: SQS queue object; a URL ending in `.fifo` marks it as a FIFO queue.
        messages: Message payloads; each is serialized with `json.dumps`.
        group: Message group ID; required (non-empty) for FIFO queues, unused otherwise.
        encoding: Optional body compression. When set, it is also sent as the
            `ContentEncoding` string message attribute.

    Returns:
        For each batch, the `Successful` result entries followed by the `Failed` ones
        from the `send_messages` response.

    Raises:
        ValueError: At call time, if `queue` is a FIFO queue and no `group` is given.
    """
    fifo: bool = queue.url.endswith(FIFO_QUEUE_NAME_SUFFIX)
    if fifo and not group:
        raise ValueError("A message group is required to send to a FIFO queue")

    return _send_message_batches(
        queue,
        messages,
        group,
        fifo=fifo,
        encoding=encoding,
    )


def _send_message_batches(
    queue: Queue,
    messages: Iterable[JsonDict],
    group: str | None,
    *,
    fifo: bool,
    encoding: Literal["gzip", "zstd"] | None,
) -> Iterable[JsonDict]:
    """Lazily send already-validated `messages`, one `send_messages` request per batch."""
    # A per-call UUID prefix plus the running index makes every entry ID unique across all
    # batches of this call. For FIFO queues it doubles as the deduplication ID, so no
    # message in the call is deduplicated against another.
    batch_id = str(uuid4())

    for message_batch in partition_all(
        MESSAGE_BATCH_SIZE,
        (
            (f"{batch_id}_{i}", message_data)
            for i, message_data in enumerate(messages)
        ),
    ):
        response: JsonDict = queue.send_messages(Entries=[
            dict(
                Id=message_id,
                MessageBody=__encode_body(
                    json.dumps(message_data),
                    encoding=encoding,
                ),
            ) | (
                # Record the codec so consumers know how to decode the body.
                dict(
                    MessageAttributes={
                        "ContentEncoding": {
                            "StringValue": encoding,
                            "DataType": "String",
                        },
                    },
                )
                if encoding else {}
            ) | (
                dict(
                    MessageDeduplicationId=message_id,
                    MessageGroupId=group,
                )
                if fifo else {}
            )
            for message_id, message_data in message_batch
        ])

        yield from response.get("Successful", [])
        yield from response.get("Failed", [])
class FifoRouter(BaseRouter[str, str]):
    """
    Router utilizing FIFO queues and groups
    - Each resource is queue base name (excluding specified prefix and `.fifo` suffix)
    - Each target is group name
        - Assuming each group name is unique across all queues in router
    - Each resource is also a target
        - Including existing ones
    """

    def __init__(
        self,
        *,
        service_resource: ServiceResource | None = None,
        queue_name_prefix: str,
        default_target_resource: str,
        encoding: Literal["gzip", "zstd"] | None = None,
    ) -> None:
        """
        Discover existing FIFO queues under `queue_name_prefix` and register them.

        Args:
            service_resource: SQS service resource; `None` uses the module-wide default.
            queue_name_prefix: Prefix shared by every queue this router manages.
            default_target_resource: Resource (queue base name) used by default; its queue
                is looked up, not created.
            encoding: Optional compression applied to every message body sent.

        Raises:
            KeyError: Presumably, when the default resource's queue does not exist (the
                lookup goes through the CRUDL dict's read path) — TODO confirm.
        """
        super().__init__(
            default_target_resource=default_target_resource,
        )

        self.__resource_dict: CRUDLDict[str, Queue] = get_resource_dict(
            service_resource=service_resource,
            queue_name_prefix=queue_name_prefix,
        )
        self.__queue_name_prefix = queue_name_prefix

        self.__queues: dict[str, Queue] = {
            default_target_resource: self.__resource_dict[
                self.__queue_name(default_target_resource)
            ],
        }
        for queue_name, queue in self.__resource_dict.items():
            # The listing filters by prefix only, so standard (non-FIFO) queues sharing the
            # prefix appear too. Stripping `.fifo` from them would mangle their names into
            # bogus resources, so skip them.
            if not queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX):
                continue

            resource = queue_name.removeprefix(queue_name_prefix).removesuffix(
                FIFO_QUEUE_NAME_SUFFIX,
            )
            self.__queues[resource] = queue

        # Every known resource is also a target of itself.
        for resource in self.__queues:
            super().register_targets(resource, [resource])

        self.__encoding: Literal["gzip", "zstd"] | None = encoding

    def __queue_name(self, resource: str) -> str:
        """Full FIFO queue name for `resource`: prefix + resource + `.fifo`."""
        return self.__queue_name_prefix + resource + FIFO_QUEUE_NAME_SUFFIX

    @override
    def register_targets(
        self,
        resource: str,
        targets: Iterable[str],
        *,
        create: bool = True,
    ) -> None:
        """
        Register `targets` (group names) for `resource`, which is always its own target too.

        The resource's queue is resolved first, and created if missing when `create` is
        true. This ordering means a missing queue with `create=False` leaves the router's
        target registrations untouched.

        Raises:
            KeyError: If the resource's queue does not exist and `create` is false.
        """
        queue_name = self.__queue_name(resource)

        if queue_name not in self.__resource_dict:
            if not create:
                raise KeyError(queue_name)
            # Assigning to a missing name creates the queue through the CRUDL dict. It is
            # a FIFO queue because of the `.fifo` suffix.
            self.__resource_dict[queue_name] = {}
        queue = self.__resource_dict[queue_name]

        super().register_targets(resource, targets)
        super().register_targets(resource, [resource])

        self.__queues[resource] = queue

    @override
    def _route_to_resource(
        self,
        data: Iterable[JsonDict],
        resource: str,
        target: str,
    ) -> Iterable[JsonDict]:
        """Send `data` to `resource`'s queue, using `target` as the FIFO message group."""
        yield from send_messages(
            self.__queues[resource],
            data,
            target,
            encoding=self.__encoding,
        )