Coverage for src/extratools_cloud/aws/sqs.py: 0%
86 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-20 20:51 -0700
1import gzip
2import json
3from collections.abc import Iterable
4from os import getenv
5from typing import Any, Literal, cast
6from uuid import uuid4
8import boto3
9import simple_zstd as zstd
10from boto3.resources.base import ServiceResource
11from extratools_core.crudl import CRUDLDict
12from extratools_core.json import JsonDict
13from extratools_core.str import compress
14from toolz.itertoolz import partition_all
16from .helpers import ClientErrorHandler
18STAGE: str = getenv("STAGE", "local")
21default_service_resource: ServiceResource = boto3.resource(
22 "sqs",
23 endpoint_url=(
24 "http://localhost:4566" if STAGE == "local"
25 else None
26 ),
27)
29type Queue = Any
31FIFO_QUEUE_NAME_SUFFIX = ".fifo"
def get_queue_json(queue: Queue) -> JsonDict:
    """Describe an SQS ``Queue`` resource as a plain JSON-style dict.

    :param queue: boto3 SQS ``Queue`` resource.
    :returns: ``{"url": <queue URL>, "attributes": <queue attributes>}``.
        Note that reading ``queue.attributes`` on a boto3 resource may
        trigger a lazy load from the service.
    """
    summary: JsonDict = dict(
        url=queue.url,
        attributes=queue.attributes,
    )
    return summary
def get_resource_dict(
    *,
    service_resource: ServiceResource | None = None,
    queue_name_prefix: str | None = None,
    json_only: bool = False,
) -> CRUDLDict[str, Queue | JsonDict]:
    """Expose SQS queues as a dict-like ``CRUDLDict`` keyed by queue name.

    Values are boto3 ``Queue`` resources or, when ``json_only`` is true, the
    ``{"url": ..., "attributes": ...}`` summaries built by ``get_queue_json``.
    This applies to both reads and listings.

    :param service_resource: SQS service resource to operate on. Falls back
        to the module-level ``default_service_resource`` when omitted.
    :param queue_name_prefix: If set, listing is filtered server-side to
        queues with this name prefix. Create/read/update/delete raise
        ``ValueError`` for any queue name that lacks it.
    :param json_only: Return JSON summaries instead of ``Queue`` resources.
    :returns: A ``CRUDLDict`` wired to SQS create / get-by-name /
        set-attributes / delete / list calls.
    """
    service_resource = service_resource or default_service_resource

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs/service-resource/index.html

    def check_queue_name(queue_name: str) -> None:
        # Keep every operation inside the configured prefix "namespace".
        if queue_name_prefix and not queue_name.startswith(queue_name_prefix):
            msg = (
                f"Queue name {queue_name!r} does not start with"
                f" required prefix {queue_name_prefix!r}"
            )
            raise ValueError(msg)

    def create_func(queue_name: str | None, attributes: dict[str, str]) -> None:
        if queue_name is None:
            msg = "A queue name is required to create an SQS queue"
            raise ValueError(msg)

        check_queue_name(queue_name)

        # SQS creates a standard queue when FifoQueue is omitted, so the
        # attribute is only sent for ".fifo" names. An explicit "false" is
        # redundant and has been reported to be rejected as an unknown
        # attribute. Caller-supplied attributes still take precedence.
        fifo_attributes: dict[str, str] = (
            {"FifoQueue": "true"}
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
            else {}
        )

        service_resource.create_queue(
            QueueName=queue_name,
            Attributes={
                **fifo_attributes,
                **attributes,
            },
        )

    # Presumably translates a botocore ClientError with code
    # "QueueDoesNotExist" into KeyError, giving dict-like lookup semantics
    # (see .helpers.ClientErrorHandler).
    @ClientErrorHandler(
        "QueueDoesNotExist",
        KeyError,
    )
    def read_func(queue_name: str) -> Queue:
        check_queue_name(queue_name)

        queue = service_resource.get_queue_by_name(
            QueueName=queue_name,
        )
        if not json_only:
            return queue

        return get_queue_json(queue)

    # NOTE(review): unlike read_func, update_func and delete_func are not
    # wrapped in ClientErrorHandler, so a missing queue is not translated to
    # KeyError here.
    def update_func(queue_name: str, attributes: dict[str, str]) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).set_attributes(
            Attributes=dict(attributes),
        )

    def delete_func(queue_name: str) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).delete()

    def list_func(_: None) -> Iterable[tuple[str, Queue]]:
        queues = (
            service_resource.queues.filter(
                QueueNamePrefix=queue_name_prefix,
            )
            if queue_name_prefix
            else service_resource.queues.all()
        )
        for queue in queues:
            # The queue name is the last path segment of the queue URL.
            queue_name = cast("str", queue.url).rsplit('/', maxsplit=1)[-1]
            yield queue_name, (
                get_queue_json(queue) if json_only
                else queue
            )

    return CRUDLDict[str, Queue](
        create_func=create_func,
        read_func=read_func,
        update_func=update_func,
        delete_func=delete_func,
        list_func=list_func,
    )
# Maximum number of entries sent per ``queue.send_messages`` call. The SQS
# SendMessageBatch API accepts at most 10 entries per request.
MESSAGE_BATCH_SIZE: int = 10
128def __encode_body(
129 body: str,
130 *,
131 encoding: Literal["gzip", "zstd"] | None = None,
132) -> str:
133 match encoding:
134 case "gzip":
135 return compress(body, gzip.compress)
136 case "zstd":
137 return compress(body, zstd.compress)
138 case None:
139 return body
140 case _:
141 raise ValueError
def send_messages(
    queue: Queue,
    messages: Iterable[JsonDict],
    group: str | None = None,
    *,
    encoding: Literal["gzip", "zstd"] | None = None,
) -> Iterable[JsonDict]:
    """Send JSON messages to an SQS queue in batches of ``MESSAGE_BATCH_SIZE``.

    Each message is serialized with ``json.dumps`` and optionally compressed
    (see ``encoding``). For FIFO queues (URL ending in ``.fifo``), every entry
    gets ``group`` as its ``MessageGroupId`` and its batch entry id as its
    ``MessageDeduplicationId``.

    Arguments are validated as soon as this function is called. Sending is
    lazy: each batch is sent only when the returned iterator reaches it, so a
    result that is never iterated sends nothing.

    :param queue: boto3 SQS ``Queue`` resource.
    :param messages: JSON-serializable message payloads.
    :param group: Message group id. Required (non-empty) for FIFO queues and
        ignored for standard queues.
    :param encoding: Optional body compression, ``"gzip"`` or ``"zstd"``.
        When set, it is also recorded in a ``ContentEncoding`` string message
        attribute.
    :returns: An iterator over the per-entry results of each batch call: the
        response's ``Successful`` entries followed by its ``Failed`` entries.
    :raises ValueError: If ``queue`` is a FIFO queue and ``group`` is empty.
    """
    fifo: bool = queue.url.endswith(FIFO_QUEUE_NAME_SUFFIX)
    # Validate here rather than in the generator body. Inside a generator,
    # this check would be deferred until the caller first iterated the result.
    if fifo and not group:
        msg = "A non-empty message group is required for FIFO queues"
        raise ValueError(msg)

    return _send_message_batches(
        queue,
        messages,
        group,
        fifo=fifo,
        encoding=encoding,
    )


def _send_message_batches(
    queue: Queue,
    messages: Iterable[JsonDict],
    group: str | None,
    *,
    fifo: bool,
    encoding: Literal["gzip", "zstd"] | None,
) -> Iterable[JsonDict]:
    """Lazily send ``messages`` batch by batch, yielding per-entry results."""
    # Entry ids are "<uuid>_<index>", unique within this call.
    batch_id = str(uuid4())

    for message_batch in partition_all(
        MESSAGE_BATCH_SIZE,
        (
            (f"{batch_id}_{i}", message_data)
            for i, message_data in enumerate(messages)
        ),
    ):
        response: JsonDict = queue.send_messages(Entries=[
            dict(
                Id=message_id,
                MessageBody=__encode_body(
                    json.dumps(message_data),
                    encoding=encoding,
                ),
            ) | (
                # Lets consumers know how to decode the body.
                dict(
                    MessageAttributes={
                        "ContentEncoding": {
                            "StringValue": encoding,
                            "DataType": "String",
                        },
                    },
                )
                if encoding else {}
            ) | (
                dict(
                    MessageDeduplicationId=message_id,
                    MessageGroupId=group,
                )
                if fifo else {}
            )
            for message_id, message_data in message_batch
        ])

        yield from response.get("Successful", [])
        yield from response.get("Failed", [])
class FifoRouter:
    """Route messages to FIFO SQS queues by message group.

    Queues are tracked by *base name*: the queue name with
    ``queue_name_prefix`` and the ``.fifo`` suffix removed. On construction,
    every existing FIFO queue under the prefix is registered under its base
    name. ``register_queue`` can map additional group names to a queue.
    Messages whose group has no registered queue go to the default queue.
    """

    def __init__(
        self,
        *,
        service_resource: ServiceResource | None = None,
        queue_name_prefix: str,
        default_queue_base_name: str,
    ) -> None:
        """Load the existing FIFO queues under ``queue_name_prefix``.

        :param service_resource: SQS service resource; the module default is
            used when omitted.
        :param queue_name_prefix: Prefix shared by all queues this router
            manages.
        :param default_queue_base_name: Base name of the fallback FIFO queue.
            The queue is looked up immediately, so it must already exist.
            The lookup presumably raises ``KeyError`` if it does not (via the
            resource dict's read handler).
        """
        self.__resource_dict: CRUDLDict[str, Queue] = get_resource_dict(
            service_resource=service_resource,
            queue_name_prefix=queue_name_prefix,
        )

        self.__default_queue: Queue = self.__resource_dict[
            queue_name_prefix + default_queue_base_name + FIFO_QUEUE_NAME_SUFFIX
        ]

        self.__queue_name_prefix = queue_name_prefix

        # Only FIFO queues are routable. Without this filter, a standard queue
        # sharing the prefix would be registered under a garbled key (its last
        # few characters chopped off as if they were ".fifo").
        self.__queues: dict[str, Queue] = {
            queue_name
            .removeprefix(queue_name_prefix)
            .removesuffix(FIFO_QUEUE_NAME_SUFFIX): queue
            for queue_name, queue in self.__resource_dict.items()
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
        }

    def register_queue(
        self,
        queue_base_name: str,
        groups: Iterable[str] | None = None,
        *,
        create: bool = True,
    ) -> None:
        """Map message groups to the FIFO queue ``queue_base_name``.

        :param queue_base_name: Queue name without prefix or ``.fifo`` suffix.
        :param groups: Group names to route to this queue. When omitted or
            empty, the base name itself is used as the only group.
        :param create: Create the queue if it does not exist (with no extra
            attributes); otherwise raise.
        :raises KeyError: If the queue does not exist and ``create`` is false.
        """
        queue_name = self.__queue_name_prefix + queue_base_name + FIFO_QUEUE_NAME_SUFFIX

        if queue_name not in self.__resource_dict:
            if not create:
                raise KeyError(queue_name)
            # Presumably dispatched to the resource dict's create_func for a
            # missing key.
            self.__resource_dict[queue_name] = {}

        queue = self.__resource_dict[queue_name]
        self.__queues.update(dict.fromkeys(groups or [queue_base_name], queue))

    def send_messages(
        self,
        messages: Iterable[JsonDict],
        group: str,
        *,
        encoding: Literal["gzip", "zstd"] | None = None,
    ) -> Iterable[JsonDict]:
        """Send ``messages`` to the queue registered for ``group``.

        Falls back to the default queue when ``group`` is not registered.
        ``group`` is also passed through as the SQS message group id. Like
        the module-level ``send_messages``, results are yielded lazily per
        batch entry.
        """
        yield from send_messages(
            self.__queues.get(group, self.__default_queue),
            messages,
            group,
            encoding=encoding,
        )