Coverage for src/extratools_cloud/aws/sqs.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-21 20:05 -0700

1import gzip 

2import json 

3from collections.abc import Iterable 

4from typing import Any, Literal, cast, override 

5from uuid import uuid4 

6 

7import simple_zstd as zstd 

8from boto3.resources.base import ServiceResource 

9from extratools_core.crudl import CRUDLDict 

10from extratools_core.json import JsonDict 

11from extratools_core.str import compress 

12from toolz.itertoolz import partition_all 

13 

14from ..common.router import BaseRouter 

15from .helpers import ClientErrorHandler, get_service_resource 

16 

# Module-wide boto3 SQS service resource, created at import time. It is used
# whenever a caller does not supply its own `service_resource`.
default_service_resource: ServiceResource = get_service_resource("sqs")

# Alias for a boto3 SQS `Queue` resource. It is typed as Any
# (presumably because boto3 resource classes are generated at runtime — TODO confirm).
type Queue = Any

# Queue names (and queue URLs) ending with this suffix are treated as FIFO queues
# throughout this module.
FIFO_QUEUE_NAME_SUFFIX = ".fifo"

22 

23 

def get_queue_json(queue: Queue) -> JsonDict:
    """
    Summarize a boto3 SQS `Queue` resource as a JSON-compatible dict.

    The summary holds the queue's `url` followed by its `attributes`.
    """
    summary: JsonDict = dict(
        url=queue.url,
        attributes=queue.attributes,
    )
    return summary

29 

30 

def get_resource_dict(
    *,
    service_resource: ServiceResource | None = None,
    queue_name_prefix: str | None = None,
    json_only: bool = False,
) -> CRUDLDict[str, Queue | JsonDict]:
    """
    Expose SQS queues as a CRUDL (create/read/update/delete/list) mapping keyed by queue name.

    Args:
        service_resource: boto3 SQS service resource; falls back to the
            module-level default when None (or otherwise falsy).
        queue_name_prefix: when set, every queue name handed to the mapping
            must start with it (ValueError otherwise), and listing is filtered
            to queues whose names start with it.
        json_only: when True, reads and listings return `get_queue_json`
            summaries instead of boto3 `Queue` resources.

    Returns:
        A `CRUDLDict` whose operations map onto SQS calls:
        - create: `create_queue`. Names ending in `.fifo` are created as FIFO
          queues. Caller-supplied attributes are passed through and take
          precedence over `FifoQueue`.
        - read: `get_queue_by_name`. A `QueueDoesNotExist` error is translated
          to KeyError by `ClientErrorHandler`.
        - update: `set_attributes` on the named queue.
        - delete: `delete` on the named queue.
        - list: yields `(queue_name, queue)` pairs; the name is the last path
          segment of the queue URL.
    """
    service_resource = service_resource or default_service_resource

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs/service-resource/index.html

    def check_queue_name(queue_name: str) -> None:
        # Confine this mapping to its prefix, if one was given.
        if queue_name_prefix and not queue_name.startswith(queue_name_prefix):
            raise ValueError(
                f"Queue name {queue_name!r} does not start with prefix {queue_name_prefix!r}",
            )

    def create_func(queue_name: str | None, attributes: dict[str, str]) -> None:
        # SQS queues must be named explicitly; there is no generated key.
        if queue_name is None:
            raise ValueError("A queue name is required to create an SQS queue")

        check_queue_name(queue_name)

        # SQS creates a standard queue when `FifoQueue` is omitted, so only FIFO
        # names send the attribute. This avoids passing an explicit "false" for
        # standard queues.
        fifo_attributes: dict[str, str] = (
            {"FifoQueue": "true"}
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
            else {}
        )

        service_resource.create_queue(
            QueueName=queue_name,
            Attributes={
                **fifo_attributes,
                # Caller-supplied attributes win over the derived ones.
                **attributes,
            },
        )

    @ClientErrorHandler(
        "QueueDoesNotExist",
        KeyError,
    )
    def read_func(queue_name: str) -> Queue:
        check_queue_name(queue_name)

        queue = service_resource.get_queue_by_name(
            QueueName=queue_name,
        )
        if not json_only:
            return queue

        return get_queue_json(queue)

    def update_func(queue_name: str, attributes: dict[str, str]) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).set_attributes(
            Attributes={
                **attributes,
            },
        )

    def delete_func(queue_name: str) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).delete()

    def list_func(_: None) -> Iterable[tuple[str, Queue]]:
        queues = (
            service_resource.queues.filter(
                QueueNamePrefix=queue_name_prefix,
            )
            if queue_name_prefix
            else service_resource.queues.all()
        )
        for queue in queues:
            # The queue name is the final path segment of the queue URL.
            queue_name = cast("str", queue.url).rsplit("/", maxsplit=1)[-1]
            yield queue_name, (
                get_queue_json(queue) if json_only
                else queue
            )

    return CRUDLDict[str, Queue](
        create_func=create_func,
        read_func=read_func,
        update_func=update_func,
        delete_func=delete_func,
        list_func=list_func,
    )

113 

114 

# Maximum number of messages grouped into a single `queue.send_messages` call by
# `send_messages`. This presumably matches the SQS SendMessageBatch limit of 10
# entries per request — confirm before changing it.
MESSAGE_BATCH_SIZE = 10

116 

117 

118def __encode_body( 

119 body: str, 

120 *, 

121 encoding: Literal["gzip", "zstd"] | None = None, 

122) -> str: 

123 match encoding: 

124 case "gzip": 

125 return compress(body, gzip.compress) 

126 case "zstd": 

127 return compress(body, zstd.compress) 

128 case None: 

129 return body 

130 case _: 

131 raise ValueError 

132 

133 

def send_messages(
    queue: Queue,
    messages: Iterable[JsonDict],
    group: str | None = None,
    *,
    encoding: Literal["gzip", "zstd"] | None = None,
) -> Iterable[JsonDict]:
    """
    Send JSON messages to an SQS queue in batches, yielding per-message results.

    Each message is serialized with `json.dumps` and optionally compressed
    according to `encoding`. The codec name is recorded in a `ContentEncoding`
    string message attribute. Messages are sent via `queue.send_messages` in
    batches of up to `MESSAGE_BATCH_SIZE` entries.

    Entry ids have the form `<random uuid>_<message index>`. For FIFO queues
    (URL ending in `.fifo`) the same id is used as the deduplication id, and
    `group` is the message group id.

    This is a generator: nothing is validated or sent until it is iterated.

    Args:
        queue: boto3 SQS `Queue` resource.
        messages: JSON-serializable dicts to send.
        group: message group id; required for FIFO queues.
        encoding: optional compression for message bodies.

    Yields:
        For each batch, the response's `Successful` entries, then its `Failed` entries.

    Raises:
        ValueError: if the queue is FIFO and `group` is empty or None, or if
            `encoding` is unsupported.
    """
    batch_id = str(uuid4())

    is_fifo: bool = queue.url.endswith(FIFO_QUEUE_NAME_SUFFIX)
    if is_fifo and not group:
        raise ValueError

    def build_entry(message_id: str, message_data: JsonDict) -> dict[str, Any]:
        # Assemble one SendMessageBatch entry, adding optional fields in a fixed order.
        entry: dict[str, Any] = {
            "Id": message_id,
            "MessageBody": __encode_body(
                json.dumps(message_data),
                encoding=encoding,
            ),
        }
        if encoding:
            entry["MessageAttributes"] = {
                "ContentEncoding": {
                    "StringValue": encoding,
                    "DataType": "String",
                },
            }
        if is_fifo:
            entry["MessageDeduplicationId"] = message_id
            entry["MessageGroupId"] = group
        return entry

    numbered_messages = (
        (f"{batch_id}_{index}", data)
        for index, data in enumerate(messages)
    )

    for chunk in partition_all(MESSAGE_BATCH_SIZE, numbered_messages):
        response: JsonDict = queue.send_messages(
            Entries=[build_entry(message_id, data) for message_id, data in chunk],
        )

        for outcome_key in ("Successful", "Failed"):
            yield from response.get(outcome_key, [])

183 

184 

class FifoRouter(BaseRouter[str, str]):
    """
    Router that delivers data through SQS FIFO queues and their message groups.

    - Each resource is a queue base name: the full queue name with the
      router's `queue_name_prefix` and the `.fifo` suffix removed.
    - Each target is a message group name.
      - Group names are assumed to be unique across all queues in the router.
    - Each resource is also registered as a target of itself.
      - This includes the FIFO queues already present under the prefix when
        the router is constructed.
    """

    def __init__(
        self,
        *,
        service_resource: ServiceResource | None = None,
        queue_name_prefix: str,
        default_target_resource: str,
        encoding: Literal["gzip", "zstd"] | None = None,
    ) -> None:
        """
        Args:
            service_resource: boto3 SQS service resource; None uses the module default.
            queue_name_prefix: prefix shared by every queue this router manages.
            default_target_resource: base name of the default queue. That queue
                is looked up, not created, so it must already exist.
            encoding: optional compression applied to every message body sent.
        """
        super().__init__(
            default_target_resource=default_target_resource,
        )

        self.__resource_dict: CRUDLDict[str, Queue] = get_resource_dict(
            service_resource=service_resource,
            queue_name_prefix=queue_name_prefix,
        )

        self.__queue_name_prefix = queue_name_prefix

        default_queue_name = self.__get_queue_name(default_target_resource)

        self.__queues: dict[str, Queue] = {
            default_target_resource: self.__resource_dict[default_queue_name],
        } | {
            queue_name.removeprefix(queue_name_prefix).removesuffix(FIFO_QUEUE_NAME_SUFFIX): queue
            for queue_name, queue in self.__resource_dict.items()
            # Standard queues can share the prefix, but they cannot carry message
            # groups. Blindly chopping a suffix they lack would also yield a
            # bogus resource name.
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
        }
        for resource in self.__queues:
            # Bypass the override: these queues are known to exist already.
            super().register_targets(resource, [resource])

        self.__encoding: Literal["gzip", "zstd"] | None = encoding

    def __get_queue_name(self, resource: str) -> str:
        """Return the full FIFO queue name for a resource base name."""
        return self.__queue_name_prefix + resource + FIFO_QUEUE_NAME_SUFFIX

    @override
    def register_targets(
        self,
        resource: str,
        targets: Iterable[str],
        *,
        create: bool = True,
    ) -> None:
        """
        Register `targets` (group names) for `resource` and track its FIFO queue.

        The queue is resolved first, and created when missing if `create` is
        True. Targets are registered only afterwards, so a missing queue with
        `create=False` leaves the router unchanged.

        Raises:
            KeyError: if the queue does not exist and `create` is False.
        """
        queue_name = self.__get_queue_name(resource)

        if queue_name not in self.__resource_dict:
            if not create:
                raise KeyError(queue_name)
            # The `.fifo` suffix makes the resource dict create a FIFO queue.
            self.__resource_dict[queue_name] = {}

        queue = self.__resource_dict[queue_name]

        super().register_targets(resource, targets)
        super().register_targets(resource, [resource])

        self.__queues[resource] = queue

    @override
    def _route_to_resource(
        self,
        data: Iterable[JsonDict],
        resource: str,
        target: str,
    ) -> Iterable[JsonDict]:
        """Send `data` to the queue of `resource`, using `target` as the message group."""
        yield from send_messages(
            self.__queues[resource],
            data,
            target,
            encoding=self.__encoding,
        )