Coverage for src/extratools_cloud/aws/sqs.py: 0%
86 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 04:26 -0700
1import gzip
2import json
3from base64 import b64encode
4from collections.abc import Iterable
5from os import getenv
6from typing import Any, Literal, cast
7from uuid import uuid4
9import boto3
10import simple_zstd as zstd
11from boto3.resources.base import ServiceResource
12from extratools_core.crudl import CRUDLDict
13from extratools_core.json import JsonDict
14from toolz.itertoolz import partition_all
16from .helpers import ClientErrorHandler
18STAGE: str = getenv("STAGE", "local")
21default_service_resource: ServiceResource = boto3.resource(
22 "sqs",
23 endpoint_url=(
24 "http://localhost:4566" if STAGE == "local"
25 else None
26 ),
27)
29type Queue = Any
31FIFO_QUEUE_NAME_SUFFIX = ".fifo"
def get_queue_json(queue: Queue) -> JsonDict:
    """Summarize a queue resource as a plain dict.

    Args:
        queue: a boto3 SQS queue resource (anything exposing ``url`` and
            ``attributes``).

    Returns:
        ``{"url": queue.url, "attributes": queue.attributes}``.
    """
    url = queue.url
    attributes = queue.attributes
    return dict(url=url, attributes=attributes)
def get_resource_dict(
    *,
    service_resource: ServiceResource | None = None,
    queue_name_prefix: str | None = None,
    json_only: bool = False,
) -> CRUDLDict[str, Queue | JsonDict]:
    """Expose SQS queues as a dict-like ``CRUDLDict`` keyed by queue name.

    Keyword Args:
        service_resource: boto3 SQS service resource to use; falls back to the
            module-level ``default_service_resource`` when not given.
        queue_name_prefix: when set, create/read/update/delete reject queue
            names that do not start with it, and listing only returns queues
            whose names start with it.
        json_only: when true, read and list return ``get_queue_json`` dicts
            instead of boto3 queue resources.

    Returns:
        A ``CRUDLDict`` whose create/read/update/delete/list callbacks call
        the SQS service resource.

    Raises:
        ValueError: (from the callbacks) for a missing or out-of-prefix name.
    """
    service_resource = service_resource or default_service_resource

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs/service-resource/index.html

    def check_queue_name(queue_name: str) -> None:
        # Keep every operation scoped to the configured prefix.
        if queue_name_prefix and not queue_name.startswith(queue_name_prefix):
            msg = (
                f"Queue name {queue_name!r} does not start with "
                f"required prefix {queue_name_prefix!r}"
            )
            raise ValueError(msg)

    def create_func(queue_name: str | None, attributes: dict[str, str]) -> None:
        if queue_name is None:
            msg = "An explicit queue name is required to create an SQS queue"
            raise ValueError(msg)

        check_queue_name(queue_name)

        # FIFO-ness is derived from the name. Send FifoQueue only for FIFO
        # queues: SQS creates a standard queue when the attribute is omitted,
        # and an explicit "false" has been known to be rejected for standard
        # queues. Caller-supplied attributes still take precedence.
        fifo_attributes: dict[str, str] = (
            {"FifoQueue": "true"}
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
            else {}
        )

        service_resource.create_queue(
            QueueName=queue_name,
            Attributes={
                **fifo_attributes,
                **attributes,
            },
        )

    # A missing queue surfaces as KeyError, giving dict-like lookup semantics.
    @ClientErrorHandler(
        "QueueDoesNotExist",
        KeyError,
    )
    def read_func(queue_name: str) -> Queue | JsonDict:
        check_queue_name(queue_name)

        queue = service_resource.get_queue_by_name(
            QueueName=queue_name,
        )
        if not json_only:
            return queue

        return get_queue_json(queue)

    def update_func(queue_name: str, attributes: dict[str, str]) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).set_attributes(
            Attributes={
                **attributes,
            },
        )

    def delete_func(queue_name: str) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).delete()

    def list_func(_: None) -> Iterable[tuple[str, Queue | JsonDict]]:
        # Let SQS do the prefix filtering server-side when a prefix is set.
        for queue in (
            service_resource.queues.filter(
                QueueNamePrefix=queue_name_prefix,
            )
            if queue_name_prefix
            else service_resource.queues.all()
        ):
            # The queue name is the last path segment of the queue URL.
            queue_name = cast("str", queue.url).rsplit("/", maxsplit=1)[-1]
            yield queue_name, (
                get_queue_json(queue) if json_only
                else queue
            )

    return CRUDLDict[str, Queue | JsonDict](
        create_func=create_func,
        read_func=read_func,
        update_func=update_func,
        delete_func=delete_func,
        list_func=list_func,
    )
# Number of entries sent per queue.send_messages (SendMessageBatch) call;
# SQS accepts at most 10 entries in a single batch request.
MESSAGE_BATCH_SIZE: int = 10
128def __encode_body(
129 body: str,
130 *,
131 encoding: Literal["gzip", "zstd"] | None = None,
132) -> str:
133 match encoding:
134 case "gzip":
135 return b64encode(gzip.compress(body.encode())).decode()
136 case "zstd":
137 return b64encode(zstd.compress(body.encode())).decode()
138 case None:
139 return body
140 case _:
141 raise ValueError
def send_messages(
    queue: Queue,
    messages: Iterable[JsonDict],
    group: str | None = None,
    *,
    encoding: Literal["gzip", "zstd"] | None = None,
) -> Iterable[JsonDict]:
    """Send JSON messages to ``queue`` in batches, yielding per-entry results.

    This is a generator: nothing is sent, and no validation happens, until
    it is iterated.

    Args:
        queue: a boto3 SQS queue resource; a URL ending in ``.fifo`` marks it
            as a FIFO queue.
        messages: JSON-serializable dicts, one per SQS message.
        group: message group ID; required (non-empty) for FIFO queues and
            ignored otherwise.
        encoding: optional body compression. When set, the encoding name is
            attached as the ``ContentEncoding`` string message attribute.

    Yields:
        The ``Successful`` entries of each batch response, then its
        ``Failed`` entries.

    Raises:
        ValueError: if the queue is FIFO and no group was given.
    """
    batch_id = str(uuid4())

    is_fifo: bool = queue.url.endswith(FIFO_QUEUE_NAME_SUFFIX)
    if is_fifo and not group:
        raise ValueError

    def build_entry(entry_id: str, payload: JsonDict) -> JsonDict:
        entry = {
            "Id": entry_id,
            "MessageBody": __encode_body(
                json.dumps(payload),
                encoding=encoding,
            ),
        }
        if encoding:
            entry["MessageAttributes"] = {
                "ContentEncoding": {
                    "StringValue": encoding,
                    "DataType": "String",
                },
            }
        if is_fifo:
            # Each entry id is unique within this call, so reusing it as the
            # deduplication id means no message of this call is deduplicated.
            entry["MessageDeduplicationId"] = entry_id
            entry["MessageGroupId"] = group
        return entry

    numbered = (
        (f"{batch_id}_{index}", payload)
        for index, payload in enumerate(messages)
    )
    for chunk in partition_all(MESSAGE_BATCH_SIZE, numbered):
        entries = [build_entry(entry_id, payload) for entry_id, payload in chunk]
        response: JsonDict = queue.send_messages(Entries=entries)

        yield from response.get("Successful", [])
        yield from response.get("Failed", [])
class FifoRouter:
    """Route messages to FIFO queues by message group.

    Every existing FIFO queue whose name starts with ``queue_name_prefix`` is
    routable by its base name (the name without the prefix and the ``.fifo``
    suffix). ``register_queue`` adds more group → queue routes. Groups without
    a route go to the default queue.
    """

    def __init__(
        self,
        *,
        service_resource: ServiceResource | None = None,
        queue_name_prefix: str,
        default_queue_base_name: str,
    ) -> None:
        """Load the default queue and the existing FIFO queues under the prefix.

        Args:
            service_resource: boto3 SQS service resource; ``None`` uses the
                module default.
            queue_name_prefix: prefix shared by every routed queue name.
            default_queue_base_name: base name of the fallback FIFO queue.

        Raises:
            KeyError: presumably when the default queue does not exist. The
                resource dict's reader maps ``QueueDoesNotExist`` to
                ``KeyError``; confirm against ``CRUDLDict.__getitem__``.
        """
        self.__resource_dict: CRUDLDict[str, Queue] = get_resource_dict(
            service_resource=service_resource,
            queue_name_prefix=queue_name_prefix,
        )

        self.__default_queue: Queue = self.__resource_dict[
            queue_name_prefix + default_queue_base_name + FIFO_QUEUE_NAME_SUFFIX
        ]

        self.__queue_name_prefix = queue_name_prefix

        # Route each existing FIFO queue by its base name. Standard queues that
        # share the prefix are skipped: blindly cutting the suffix length off
        # their names would produce bogus routing keys, and message groups are
        # not applied to non-FIFO queues anyway.
        self.__queues: dict[str, Queue] = {
            queue_name.removeprefix(queue_name_prefix).removesuffix(
                FIFO_QUEUE_NAME_SUFFIX,
            ): queue
            for queue_name, queue in self.__resource_dict.items()
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
        }

    def register_queue(
        self,
        queue_base_name: str,
        groups: Iterable[str] | None = None,
        *,
        create: bool = True,
    ) -> None:
        """Route ``groups`` to the FIFO queue ``<prefix><queue_base_name>.fifo``.

        Args:
            queue_base_name: queue name without the prefix and ``.fifo``.
            groups: message groups to route to that queue. When ``None`` or
                empty, the queue is routed by ``queue_base_name`` itself.
            create: when the queue does not exist, create it with no extra
                attributes (if true) or raise (if false).

        Raises:
            KeyError: if the queue does not exist and ``create`` is false.
        """
        queue_name = self.__queue_name_prefix + queue_base_name + FIFO_QUEUE_NAME_SUFFIX

        if queue_name not in self.__resource_dict:
            if not create:
                raise KeyError(queue_name)
            # Assigning attributes to a missing key presumably dispatches to the
            # resource dict's create callback.
            self.__resource_dict[queue_name] = {}

        queue = self.__resource_dict[queue_name]

        # Materialize once so an empty generator falls back like an empty list.
        route_keys = list(groups) if groups is not None else []
        self.__queues.update(dict.fromkeys(route_keys or [queue_base_name], queue))

    def send_messages(
        self,
        messages: Iterable[JsonDict],
        group: str,
        *,
        encoding: Literal["gzip", "zstd"] | None = None,
    ) -> Iterable[JsonDict]:
        """Send ``messages`` with message group ``group`` to its routed queue.

        Unrouted groups go to the default queue. This is a generator that
        yields the module-level ``send_messages`` results lazily.
        """
        yield from send_messages(
            self.__queues.get(group, self.__default_queue),
            messages,
            group,
            encoding=encoding,
        )