Coverage for src/extratools_cloud/aws/sqs.py: 0%

86 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-20 20:51 -0700

1import gzip 

2import json 

3from collections.abc import Iterable 

4from os import getenv 

5from typing import Any, Literal, cast 

6from uuid import uuid4 

7 

8import boto3 

9import simple_zstd as zstd 

10from boto3.resources.base import ServiceResource 

11from extratools_core.crudl import CRUDLDict 

12from extratools_core.json import JsonDict 

13from extratools_core.str import compress 

14from toolz.itertoolz import partition_all 

15 

16from .helpers import ClientErrorHandler 

17 

18STAGE: str = getenv("STAGE", "local") 

19 

20 

21default_service_resource: ServiceResource = boto3.resource( 

22 "sqs", 

23 endpoint_url=( 

24 "http://localhost:4566" if STAGE == "local" 

25 else None 

26 ), 

27) 

28 

29type Queue = Any 

30 

31FIFO_QUEUE_NAME_SUFFIX = ".fifo" 

32 

33 

def get_queue_json(queue: Queue) -> JsonDict:
    """Return a plain-dict snapshot of *queue*.

    The result has exactly two keys, in this order: ``"url"`` (the queue's
    ``url`` property) and ``"attributes"`` (its ``attributes`` property).
    """
    url = queue.url
    attributes = queue.attributes
    return dict(url=url, attributes=attributes)

39 

40 

def get_resource_dict(
    *,
    service_resource: ServiceResource | None = None,
    queue_name_prefix: str | None = None,
    json_only: bool = False,
) -> CRUDLDict[str, Queue | JsonDict]:
    """Expose SQS queues as a dict-like ``CRUDLDict`` keyed by queue name.

    Args:
        service_resource: boto3 SQS service resource to use; falls back to the
            module-level ``default_service_resource``.
        queue_name_prefix: when given, create/read/update/delete reject queue
            names that do not start with it (``ValueError``), and listing only
            returns queues whose names start with it.
        json_only: when true, read and list return ``get_queue_json`` snapshots
            instead of boto3 ``Queue`` resources.

    Returns:
        A ``CRUDLDict`` whose operations map onto SQS calls:
        create -> ``create_queue``, read -> ``get_queue_by_name``,
        update -> ``Queue.set_attributes``, delete -> ``Queue.delete``,
        list -> ``queues.filter`` / ``queues.all``.
    """
    service_resource = service_resource or default_service_resource

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs/service-resource/index.html

    def check_queue_name(queue_name: str) -> None:
        # Refuse to touch queues outside the configured name prefix.
        if queue_name_prefix and not queue_name.startswith(queue_name_prefix):
            msg = (
                f"Queue name {queue_name!r} does not start with "
                f"prefix {queue_name_prefix!r}"
            )
            raise ValueError(msg)

    def create_func(queue_name: str | None, attributes: dict[str, str]) -> None:
        if queue_name is None:
            # SQS does not generate queue names, so a name is mandatory.
            msg = "A queue name is required to create an SQS queue"
            raise ValueError(msg)

        check_queue_name(queue_name)

        # FIFO-ness is inferred from the name. FifoQueue is only sent for FIFO
        # queues: without it SQS creates a standard queue, whereas an explicit
        # FifoQueue=false has been reported to fail on real SQS
        # (InvalidAttributeName). Caller-supplied attributes still win.
        default_attributes: dict[str, str] = (
            {"FifoQueue": "true"}
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
            else {}
        )

        service_resource.create_queue(
            QueueName=queue_name,
            Attributes={
                **default_attributes,
                **attributes,
            },
        )

    # Presumably translates the SQS "QueueDoesNotExist" client error into a
    # KeyError so that missing queues behave like missing dict keys.
    @ClientErrorHandler(
        "QueueDoesNotExist",
        KeyError,
    )
    def read_func(queue_name: str) -> Queue:
        check_queue_name(queue_name)

        queue = service_resource.get_queue_by_name(
            QueueName=queue_name,
        )
        if not json_only:
            return queue

        return get_queue_json(queue)

    def update_func(queue_name: str, attributes: dict[str, str]) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).set_attributes(
            # Pass a copy so the caller's mapping is never shared with boto3.
            Attributes=dict(attributes),
        )

    def delete_func(queue_name: str) -> None:
        check_queue_name(queue_name)

        service_resource.get_queue_by_name(
            QueueName=queue_name,
        ).delete()

    def list_func(_: None) -> Iterable[tuple[str, Queue]]:
        # Let SQS do the prefix filtering server-side when a prefix is set.
        queues = (
            service_resource.queues.filter(
                QueueNamePrefix=queue_name_prefix,
            )
            if queue_name_prefix
            else service_resource.queues.all()
        )
        for queue in queues:
            # The queue name is the last path segment of the queue URL.
            queue_name = cast("str", queue.url).rsplit("/", maxsplit=1)[-1]
            yield queue_name, (
                get_queue_json(queue) if json_only
                else queue
            )

    return CRUDLDict[str, Queue | JsonDict](
        create_func=create_func,
        read_func=read_func,
        update_func=update_func,
        delete_func=delete_func,
        list_func=list_func,
    )

123 

124 

# Messages per SendMessageBatch request; SQS accepts at most 10 entries in one
# batch, so send_messages chunks its input into groups of this size.
MESSAGE_BATCH_SIZE: int = 10

126 

127 

128def __encode_body( 

129 body: str, 

130 *, 

131 encoding: Literal["gzip", "zstd"] | None = None, 

132) -> str: 

133 match encoding: 

134 case "gzip": 

135 return compress(body, gzip.compress) 

136 case "zstd": 

137 return compress(body, zstd.compress) 

138 case None: 

139 return body 

140 case _: 

141 raise ValueError 

142 

143 

def send_messages(
    queue: Queue,
    messages: Iterable[JsonDict],
    group: str | None = None,
    *,
    encoding: Literal["gzip", "zstd"] | None = None,
) -> Iterable[JsonDict]:
    """Send *messages* to *queue* as JSON, ``MESSAGE_BATCH_SIZE`` per request.

    This is a generator: validation and sending only happen once it is iterated.

    Args:
        queue: boto3 SQS ``Queue`` resource; it is treated as FIFO when its URL
            ends with ``FIFO_QUEUE_NAME_SUFFIX``.
        messages: objects to send; each is serialized with ``json.dumps``.
        group: message group id; required for FIFO queues, unused otherwise.
        encoding: optional compression for each body. When set, it is also
            recorded in a ``ContentEncoding`` string message attribute.

    Yields:
        For each batch, the response's ``Successful`` entries followed by its
        ``Failed`` entries.

    Raises:
        ValueError: if *queue* is FIFO and *group* is empty or missing.
    """
    call_id = str(uuid4())

    is_fifo: bool = queue.url.endswith(FIFO_QUEUE_NAME_SUFFIX)
    if is_fifo and not group:
        raise ValueError

    # Ids combine a per-call UUID with the message's position, so they are
    # unique across calls; on FIFO queues they double as deduplication ids.
    numbered = (
        (f"{call_id}_{index}", payload)
        for index, payload in enumerate(messages)
    )

    for chunk in partition_all(MESSAGE_BATCH_SIZE, numbered):
        entries = []
        for entry_id, payload in chunk:
            entry = dict(
                Id=entry_id,
                MessageBody=__encode_body(
                    json.dumps(payload),
                    encoding=encoding,
                ),
            )
            if encoding:
                entry |= dict(
                    MessageAttributes={
                        "ContentEncoding": {
                            "StringValue": encoding,
                            "DataType": "String",
                        },
                    },
                )
            if is_fifo:
                entry |= dict(
                    MessageDeduplicationId=entry_id,
                    MessageGroupId=group,
                )
            entries.append(entry)

        response: JsonDict = queue.send_messages(Entries=entries)

        yield from response.get("Successful", [])
        yield from response.get("Failed", [])

193 

194 

class FifoRouter:
    """Route messages to FIFO queues by message group.

    Every existing FIFO queue named ``<queue_name_prefix><base name>.fifo`` is
    known to the router under its base name. Further group names can be mapped
    to a queue with ``register_queue``. Groups without a mapping are sent to
    the default queue.
    """

    def __init__(
        self,
        *,
        service_resource: ServiceResource | None = None,
        queue_name_prefix: str,
        default_queue_base_name: str,
    ) -> None:
        """Load the default queue and index existing FIFO queues under the prefix.

        Args:
            service_resource: boto3 SQS service resource; ``None`` uses the
                module default.
            queue_name_prefix: prefix shared by every queue this router manages.
            default_queue_base_name: base name of the fallback queue; the full
                name is ``queue_name_prefix + default_queue_base_name + ".fifo"``.

        Raises:
            KeyError: presumably, when the default queue does not exist. The
                lookup goes through ``get_resource_dict``'s read, which maps
                ``QueueDoesNotExist`` to ``KeyError``.
        """
        self.__resource_dict: CRUDLDict[str, Queue] = get_resource_dict(
            service_resource=service_resource,
            queue_name_prefix=queue_name_prefix,
        )

        self.__default_queue: Queue = self.__resource_dict[
            queue_name_prefix + default_queue_base_name + FIFO_QUEUE_NAME_SUFFIX
        ]

        self.__queue_name_prefix = queue_name_prefix

        # Index every existing FIFO queue under the prefix by its base name.
        # Standard queues that share the prefix are skipped. Blindly chopping
        # len(".fifo") characters off their names produced bogus routing keys,
        # and routing a group to a standard queue would silently drop the
        # group id.
        self.__queues: dict[str, Queue] = {
            queue_name.removeprefix(queue_name_prefix).removesuffix(
                FIFO_QUEUE_NAME_SUFFIX,
            ): queue
            for queue_name, queue in self.__resource_dict.items()
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
        }

    def register_queue(
        self,
        queue_base_name: str,
        groups: Iterable[str] | None = None,
        *,
        create: bool = True,
    ) -> None:
        """Route *groups* to the FIFO queue ``<prefix><queue_base_name>.fifo``.

        Args:
            queue_base_name: queue name without the router prefix and ``.fifo``.
            groups: group names to route to this queue. When omitted (or an
                empty sequence), the queue is registered under
                ``queue_base_name`` itself.
            create: whether to create the queue (with no extra attributes) when
                it does not exist yet.

        Raises:
            KeyError: if the queue does not exist and *create* is false.
        """
        queue_name = self.__queue_name_prefix + queue_base_name + FIFO_QUEUE_NAME_SUFFIX

        if queue_name not in self.__resource_dict:
            if not create:
                raise KeyError(queue_name)
            # Assigning attributes to a missing key presumably dispatches to
            # the resource dict's create operation (creating the FIFO queue).
            self.__resource_dict[queue_name] = {}

        queue = self.__resource_dict[queue_name]
        self.__queues.update(dict.fromkeys(groups or [queue_base_name], queue))

    def send_messages(
        self,
        messages: Iterable[JsonDict],
        group: str,
        *,
        encoding: Literal["gzip", "zstd"] | None = None,
    ) -> Iterable[JsonDict]:
        """Send *messages* with message group *group* to that group's queue.

        Falls back to the default queue when *group* has no registered queue.
        Yields the per-message results of the module-level ``send_messages``.
        """
        # Inside a method, the bare name resolves to the module-level
        # send_messages function, not to this method.
        yield from send_messages(
            self.__queues.get(group, self.__default_queue),
            messages,
            group,
            encoding=encoding,
        )