Coverage for src/extratools_cloud/aws/sqs.py: 0%

94 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-21 09:46 -0700

1import gzip 

2import json 

3from collections.abc import Iterable 

4from os import getenv 

5from typing import Any, Literal, cast, override 

6from uuid import uuid4 

7 

8import boto3 

9import simple_zstd as zstd 

10from boto3.resources.base import ServiceResource 

11from extratools_core.crudl import CRUDLDict 

12from extratools_core.json import JsonDict 

13from extratools_core.str import compress 

14from toolz.itertoolz import partition_all 

15 

16from ..common.router import BaseRouter 

17from .helpers import ClientErrorHandler 

18 

# Deployment stage taken from the STAGE environment variable; "local" when unset.
STAGE: str = getenv("STAGE", "local")


# Module-wide SQS resource used whenever a caller does not pass its own.
# For the "local" stage it targets http://localhost:4566 (presumably a local AWS
# emulator - TODO confirm); otherwise endpoint_url is None and boto3 picks the endpoint.
# NOTE: this is created at import time.
default_service_resource: ServiceResource = boto3.resource(
    "sqs",
    endpoint_url=(
        "http://localhost:4566" if STAGE == "local"
        else None
    ),
)

# Alias for the SQS queue objects handled in this module; deliberately left as Any.
type Queue = Any

# Name suffix used throughout this module to recognise FIFO queues.
FIFO_QUEUE_NAME_SUFFIX = ".fifo"

33 

34 

def get_queue_json(queue: Queue) -> JsonDict:
    """
    Summarize a queue as a plain dict.

    The result has exactly two keys: `url` (the queue's `url`) and
    `attributes` (the queue's `attributes`), read in that order.
    """
    queue_json: JsonDict = dict(
        url=queue.url,
        attributes=queue.attributes,
    )
    return queue_json

40 

41 

def get_resource_dict(
    *,
    service_resource: ServiceResource | None = None,
    queue_name_prefix: str | None = None,
    json_only: bool = False,
) -> CRUDLDict[str, Queue | JsonDict]:
    """
    Build a CRUDL dict over SQS queues, keyed by queue name.

    Args:
        service_resource: SQS service resource to operate on; falls back to
            `default_service_resource` when omitted.
        queue_name_prefix: When set, every queue name passed to create/read/update/delete
            must start with it, and listing is filtered by it.
        json_only: When true, reads and listings return the `get_queue_json` summary
            (`url` + `attributes`) instead of the queue object itself.

    Returns:
        A `CRUDLDict` wired to the create/read/update/delete/list callbacks below.
        (How dict operations map onto those callbacks is defined by `CRUDLDict`.)

    Raises (from the callbacks):
        ValueError: The queue name is missing on create, or does not start with
            `queue_name_prefix`.
    """

    service_resource = service_resource or default_service_resource

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs/service-resource/index.html

    def check_queue_name(queue_name: str) -> None:
        # An unset/empty prefix disables the check entirely.
        if queue_name_prefix and not queue_name.startswith(queue_name_prefix):
            raise ValueError(
                f"Queue name {queue_name!r} does not start with"
                f" required prefix {queue_name_prefix!r}",
            )

    def create_func(queue_name: str | None, attributes: dict[str, str]) -> None:
        if queue_name is None:
            raise ValueError("A queue name is required to create a queue")

        check_queue_name(queue_name)

        # Only mark FIFO queues explicitly. Omitting `FifoQueue` creates a standard queue,
        # whereas sending an explicit "false" has been reported to be rejected by SQS
        # (InvalidAttributeName) - so it is never sent for standard queues.
        fifo_attributes: dict[str, str] = (
            {"FifoQueue": "true"}
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
            else {}
        )

        service_resource.create_queue(
            QueueName=queue_name,
            Attributes={
                **fifo_attributes,
                # Caller-supplied attributes win over the derived one.
                **attributes,
            },
        )

    @ClientErrorHandler(
        "QueueDoesNotExist",
        KeyError,
    )
    def read_func(queue_name: str) -> Queue | JsonDict:
        check_queue_name(queue_name)

        queue = service_resource.get_queue_by_name(
            QueueName=queue_name,
        )
        if not json_only:
            return queue

        return get_queue_json(queue)

    def update_func(queue_name: str, attributes: dict[str, str]) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).set_attributes(
            Attributes={
                **attributes,
            },
        )

    def delete_func(queue_name: str) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).delete()

    def list_func(_: None) -> Iterable[tuple[str, Queue | JsonDict]]:
        for queue in (
            service_resource.queues.filter(
                QueueNamePrefix=queue_name_prefix,
            )
            if queue_name_prefix
            else service_resource.queues.all()
        ):
            # The queue name is the last path segment of the queue URL.
            queue_name = cast("str", queue.url).rsplit('/', maxsplit=1)[-1]
            yield queue_name, (
                get_queue_json(queue) if json_only
                else queue
            )

    return CRUDLDict[str, Queue](
        create_func=create_func,
        read_func=read_func,
        update_func=update_func,
        delete_func=delete_func,
        list_func=list_func,
    )

124 

125 

# Number of messages `send_messages` groups into each `queue.send_messages` call.
# NOTE(review): 10 appears to match the SQS SendMessageBatch per-request entry limit -
# confirm against the AWS documentation before raising it.
MESSAGE_BATCH_SIZE = 10

127 

128 

129def __encode_body( 

130 body: str, 

131 *, 

132 encoding: Literal["gzip", "zstd"] | None = None, 

133) -> str: 

134 match encoding: 

135 case "gzip": 

136 return compress(body, gzip.compress) 

137 case "zstd": 

138 return compress(body, zstd.compress) 

139 case None: 

140 return body 

141 case _: 

142 raise ValueError 

143 

144 

def send_messages(
    queue: Queue,
    messages: Iterable[JsonDict],
    group: str | None = None,
    *,
    encoding: Literal["gzip", "zstd"] | None = None,
) -> Iterable[JsonDict]:
    """
    Send JSON messages to a queue in batches, yielding the per-entry results.

    Each message is serialized with `json.dumps`, passed through `__encode_body`, and
    sent in batches of `MESSAGE_BATCH_SIZE` via `queue.send_messages`. Entry ids have the
    form `<uuid4>_<index>`. When `encoding` is set, it is recorded in a `ContentEncoding`
    message attribute. For a queue whose URL ends with `.fifo`, the entry id doubles as
    the deduplication id and `group` is used as the message group id.

    This is a generator: nothing is checked or sent until it is iterated.

    Yields:
        For each batch response, its `Successful` entries followed by its `Failed` ones.

    Raises:
        ValueError: The queue is FIFO and `group` is empty or None.
    """
    id_prefix = str(uuid4())

    is_fifo: bool = queue.url.endswith(FIFO_QUEUE_NAME_SUFFIX)
    if is_fifo and not group:
        raise ValueError

    def build_entry(entry_id: str, payload: JsonDict) -> dict[str, Any]:
        # Keys are added in a fixed order: id/body, optional encoding marker, FIFO fields.
        entry: dict[str, Any] = {
            "Id": entry_id,
            "MessageBody": __encode_body(
                json.dumps(payload),
                encoding=encoding,
            ),
        }
        if encoding:
            entry["MessageAttributes"] = {
                "ContentEncoding": {
                    "StringValue": encoding,
                    "DataType": "String",
                },
            }
        if is_fifo:
            entry["MessageDeduplicationId"] = entry_id
            entry["MessageGroupId"] = group
        return entry

    # Ids are assigned lazily so `messages` is consumed one batch at a time.
    identified_messages = (
        (f"{id_prefix}_{index}", payload)
        for index, payload in enumerate(messages)
    )

    for chunk in partition_all(MESSAGE_BATCH_SIZE, identified_messages):
        entries = [
            build_entry(entry_id, payload)
            for entry_id, payload in chunk
        ]
        response: JsonDict = queue.send_messages(Entries=entries)

        for result_key in ("Successful", "Failed"):
            yield from response.get(result_key, [])

194 

195 

class FifoRouter(BaseRouter[str, str]):
    """
    Router utilizing FIFO queues and groups
    - Each resource is queue base name (excluding specified prefix and `.fifo` suffix)
    - Each target is group name
        - Assuming each group name is unique across all queues in router
    - Each resource is also a target
        - Including existing ones
    - Queues sharing the prefix but lacking the `.fifo` suffix are ignored
    """

    def __init__(
        self,
        *,
        service_resource: ServiceResource | None = None,
        queue_name_prefix: str,
        default_target_resource: str,
        encoding: Literal["gzip", "zstd"] | None = None,
    ) -> None:
        """
        Discover the existing FIFO queues under `queue_name_prefix` and register each
        one as a resource that is also its own target.

        Args:
            service_resource: Passed through to `get_resource_dict`.
            queue_name_prefix: Prefix shared by every queue managed by this router.
            default_target_resource: Resource (queue base name) handed to the base
                router as its default; its queue is read during construction, so a
                failed read propagates from here.
            encoding: Optional body encoding applied to every message sent.
        """
        super().__init__(
            default_target_resource=default_target_resource,
        )

        self.__resource_dict: CRUDLDict[str, Queue] = get_resource_dict(
            service_resource=service_resource,
            queue_name_prefix=queue_name_prefix,
        )
        self.__queue_name_prefix = queue_name_prefix
        self.__encoding: Literal["gzip", "zstd"] | None = encoding

        default_queue_name = queue_name_prefix + default_target_resource + FIFO_QUEUE_NAME_SUFFIX

        self.__queues: dict[str, Queue] = {
            # Read explicitly so that a missing default queue fails construction.
            default_target_resource: self.__resource_dict[default_queue_name],
        } | {
            queue_name.removeprefix(queue_name_prefix).removesuffix(FIFO_QUEUE_NAME_SUFFIX): queue
            for queue_name, queue in self.__resource_dict.items()
            # Non-FIFO queues may share the prefix; blindly slicing off a suffix-sized
            # tail would register them under a truncated, wrong resource name.
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
        }
        for resource in self.__queues:
            # Base-class registration only: these queues already exist, so the
            # queue-creating override below must not run for them.
            super().register_targets(resource, [resource])

    @override
    def register_targets(
        self,
        resource: str,
        targets: Iterable[str],
        *,
        create: bool = True,
    ) -> None:
        """
        Register `targets` (plus `resource` itself) for `resource` and bind its queue.

        Args:
            resource: Queue base name (without prefix and `.fifo` suffix).
            targets: Group names to route to this resource.
            create: Create the queue when it does not exist yet.

        Raises:
            KeyError: The queue does not exist and `create` is false. Raised before any
                routing state is changed.
        """
        queue_name = self.__queue_name_prefix + resource + FIFO_QUEUE_NAME_SUFFIX

        queue_exists = queue_name in self.__resource_dict
        if not queue_exists and not create:
            # Fail first: otherwise targets would stay registered for a resource that
            # has no queue to send to.
            raise KeyError(queue_name)

        super().register_targets(resource, targets)
        super().register_targets(resource, [resource])

        if not queue_exists:
            # Empty attributes; `get_resource_dict` derives the FIFO flag from the name.
            self.__resource_dict[queue_name] = {}

        self.__queues[resource] = self.__resource_dict[queue_name]

    @override
    def _route_to_resource(
        self,
        data: Iterable[JsonDict],
        resource: str,
        target: str,
    ) -> Iterable[JsonDict]:
        """
        Send `data` to the queue bound to `resource`, using `target` as the message
        group, and yield the per-message results from `send_messages`.
        """
        yield from send_messages(
            self.__queues[resource],
            data,
            target,
            encoding=self.__encoding,
        )

273 )