Coverage for /Users/eugene/Development/robotnikmq/robotnikmq/rpc_server.py: 53%

180 statements  

« prev     ^ index     » next       coverage.py v7.3.4, created at 2023-12-26 19:13 -0500

1from dataclasses import dataclass 

2from inspect import signature, Parameter 

3from json import loads as _from_json 

4from traceback import format_exc 

5from socket import gethostname 

6from typing import Optional, Callable, Union, Any, Dict, Tuple, List, TypedDict, Type 

7from typing import get_type_hints, get_origin, get_args 

8from uuid import uuid4 as uuid, UUID 

9 

10from pika import BasicProperties 

11from pika.exceptions import AMQPError, ChannelError, AMQPConnectionError 

12from tenacity import retry, wait_exponential, retry_if_exception_type 

13from typeguard import typechecked 

14 

15from robotnikmq.config import RobotnikConfig 

16from robotnikmq.core import Robotnik, thread_name, valid_json 

17from robotnikmq.utils import to_json as _to_json 

18from robotnikmq.log import log 

19 

20 

@typechecked
def _type_hint_str(typ: Any) -> str:
    """Render a resolved type hint as a short, human-readable string.

    Union hints (``Union[...]``, ``Optional[X]`` and, on Python 3.10+, the
    PEP 604 ``X | Y`` form) are rendered recursively as ``Union[A,B,...]``.
    Every other hint is rendered by its ``__name__`` when it has one.  Hints
    without a ``__name__`` fall back to ``str(typ)`` instead of raising
    ``AttributeError``.

    Args:
        typ: A type hint, typically a value returned by ``get_type_hints``.

    Returns:
        The string representation of ``typ``.
    """
    import types  # local import: only needed to recognise PEP 604 unions

    origin = get_origin(typ)
    # ``types.UnionType`` only exists on 3.10+; on older versions the getattr
    # default makes the second test a harmless repeat of the first.
    if origin is Union or origin is getattr(types, "UnionType", Union):
        return f"Union[{','.join(_type_hint_str(t) for t in get_args(typ))}]"
    name = getattr(typ, "__name__", None)
    return str(name) if name is not None else str(typ)

26 

# Wire format of an RPC error message, as returned by ``RpcError.to_dict``.
# Declared with the functional ``TypedDict`` syntax, which yields the same
# class as the equivalent class-based declaration.
RpcErrorTypedDict = typechecked(
    TypedDict(
        "RpcErrorTypedDict",
        {
            # Stringified request/correlation ID.
            "request_id": str,
            # Message discriminator; RpcError.to_dict always sets "error".
            "type": str,
            # Error payload: nothing, a message, or a JSON-object-like dict.
            "details": Union[None, str, Dict[str, Any]],
        },
    )
)

32 

@typechecked
@dataclass(frozen=True)
class RpcError:
    """Immutable error reply for an RPC request.

    Attributes:
        request_id: ID of the request this error answers (``str`` or ``UUID``).
        details: Error information: ``None``, a message string, or a
            JSON-object-like dict.
    """

    request_id: Union[str, UUID]
    details: Union[None, str, Dict[str, Any]]

    @staticmethod
    def of(
        request_id: Union[str, UUID, None] = None,
        details: Union[None, str, Dict[str, Any]] = None,
    ) -> 'RpcError':
        """Build an ``RpcError``; a falsy ``request_id`` is replaced by a random UUID4."""
        return RpcError(request_id or uuid(), details)

    def to_json(self) -> str:
        """Serialise this error to JSON (the shape produced by ``to_dict``)."""
        return _to_json(self.to_dict())

    def to_dict(self) -> RpcErrorTypedDict:
        """Return the wire form ``{"request_id": str, "type": "error", "details": ...}``.

        ``request_id`` is always converted with ``str()`` so UUIDs serialise.
        """
        return {
            "request_id": str(self.request_id),
            "type": "error",
            "details": self.details,
        }

    @staticmethod
    def from_json(json_str: Union[str, bytes]) -> Optional['RpcError']:
        """Parse an error message such as one produced by ``to_json``.

        Args:
            json_str: JSON text, or bytes holding UTF-8 encoded JSON text.

        Returns:
            The parsed ``RpcError``. Returns ``None`` when the input is not
            valid UTF-8, is not valid JSON, is not a JSON object, or lacks any
            of the ``request_id``, ``type`` and ``details`` keys.
        """
        if not isinstance(json_str, str):
            try:
                json_str = json_str.decode()
            except UnicodeDecodeError:
                return None
        log.debug(json_str)
        if not valid_json(json_str):
            return None
        data = _from_json(json_str)
        # Valid JSON need not be an object. For a list holding the key names,
        # the membership test would pass and indexing would then raise
        # TypeError. For a number, the ``in`` test itself would raise.
        if not isinstance(data, dict):
            return None
        if all(k in data for k in ("request_id", "type", "details")):
            return RpcError.of(request_id=data["request_id"], details=data["details"])
        return None

65 

66 

# Wire format of an RPC success message, as returned by ``RpcResponse.to_dict``.
# Declared with the functional ``TypedDict`` syntax, which yields the same
# class as the equivalent class-based declaration.
RpcResponseTypedDict = typechecked(
    TypedDict(
        "RpcResponseTypedDict",
        {
            # Stringified request/correlation ID.
            "request_id": str,
            # Message discriminator; RpcResponse.to_dict always sets "response".
            "type": str,
            # The RPC result.
            "data": Union[None, str, int, float, Dict[str, Any], List[Dict[str, Any]]],
        },
    )
)

72 

73 

@typechecked
@dataclass(frozen=True)
class RpcResponse:
    """Immutable successful reply for an RPC request.

    Attributes:
        request_id: ID of the request this reply answers (``str`` or ``UUID``).
        data: The RPC result: ``None``, a scalar (``str``/``int``/``float``),
            a dict, or a list of dicts.
    """

    request_id: Union[str, UUID]
    data: Union[None, str, int, float, Dict[str, Any], List[Dict[str, Any]]]

    @staticmethod
    def of(
        request_id: Union[str, UUID, None] = None,
        data: Union[None, str, int, float, Dict[str, Any], List[Dict[str, Any]]] = None,
    ) -> 'RpcResponse':
        """Build an ``RpcResponse``; a falsy ``request_id`` is replaced by a random UUID4."""
        return RpcResponse(request_id or uuid(), data)

    def to_dict(self) -> RpcResponseTypedDict:
        """Return the wire form ``{"request_id": str, "type": "response", "data": ...}``.

        ``request_id`` is always converted with ``str()`` so UUIDs serialise.
        """
        return {
            "request_id": str(self.request_id),
            "type": "response",
            "data": self.data,
        }

    def to_json(self) -> str:
        """Serialise this reply to JSON (the shape produced by ``to_dict``)."""
        return _to_json(self.to_dict())

    @staticmethod
    def from_json(json_str: Union[str, bytes]) -> Optional["RpcResponse"]:
        """Parse a reply message such as one produced by ``to_json``.

        Args:
            json_str: JSON text, or bytes holding UTF-8 encoded JSON text.

        Returns:
            The parsed ``RpcResponse``. Returns ``None`` when the input is not
            valid UTF-8, is not valid JSON, is not a JSON object, or lacks any
            of the ``request_id``, ``type`` and ``data`` keys.
        """
        if not isinstance(json_str, str):
            try:
                json_str = json_str.decode()
            except UnicodeDecodeError:
                return None
        if not valid_json(json_str):
            return None
        data = _from_json(json_str)
        # Valid JSON need not be an object. For a list holding the key names,
        # the membership test would pass and indexing would then raise
        # TypeError. For a number, the ``in`` test itself would raise.
        if not isinstance(data, dict):
            return None
        if all(k in data for k in ("request_id", "type", "data")):
            return RpcResponse.of(request_id=data["request_id"], data=data["data"])
        return None

105 

106 

class RpcServer(Robotnik):
    """RabbitMQ RPC server that dispatches JSON requests to registered callbacks.

    Each RPC registered with ``register_rpc`` is served from its own queue.

    - A request body must be a JSON object whose keys are the callback's
      keyword arguments.
    - The body is validated against the callback's type hints before the
      callback runs.
    - The reply (``RpcResponse`` or ``RpcError``, as JSON) is published with
      the request's ``correlation_id``, via the default exchange, using the
      request's ``reply_to`` as routing key.
    - Optionally, a companion ``queue + docs_queue_suffix`` queue answers with
      a description of the RPC.
    """

    @typechecked
    def __init__(
        self,
        config: Optional[RobotnikConfig] = None,
        meta_queue_prefix: Optional[str] = None,
        docs_queue_suffix: Optional[str] = None,
        only_once: bool = False,
    ):
        """Create the server.

        Args:
            config: Configuration passed through to ``Robotnik.__init__``.
            meta_queue_prefix: Stored as ``self.meta_queue_prefix``. Defaults
                to this machine's hostname. It is not otherwise used within
                this class.
            docs_queue_suffix: Suffix appended to an RPC queue name to form
                its documentation queue. Defaults to ``".__doc__"``.
            only_once: If true, stop consuming after the first RPC request has
                been processed.
        """
        super().__init__(config=config)
        self._callbacks: Dict[str, Callable] = {}
        self.meta_queue_prefix = meta_queue_prefix or gethostname()
        self.docs_queue_suffix = docs_queue_suffix or ".__doc__"
        # Typically used for testing, implies server should stop after 1 response
        self.only_once = only_once

    @typechecked
    def _register_docs(self, queue: str, callback: Callable) -> None:
        """Serve a description of the ``queue`` RPC on ``queue + docs_queue_suffix``.

        The reply data contains:

        - the RPC queue name;
        - the parameter types (as strings);
        - the return type (as a string, or ``None`` if unannotated);
        - the callback's docstring.
        """
        docs_queue = queue + self.docs_queue_suffix
        self.channel.queue_declare(queue=docs_queue, exclusive=False)

        @typechecked
        def docs_callback(_, method, props: BasicProperties, __) -> None:
            req_id = props.correlation_id or uuid()
            try:
                response = RpcResponse.of(
                    req_id,
                    data={
                        "rpc_queue": queue,
                        "inputs": self._get_input_type_strings(queue),
                        "returns": self._get_return_type_str(queue),
                        "description": callback.__doc__,
                    },
                ).to_json()
            except Exception:  # pylint: disable=W0703
                # Introspecting annotations can fail, e.g. on unresolvable
                # forward references. Reply with an error instead of letting
                # the exception propagate out of the consumer callback.
                self.log.error(
                    f"An error has occurred while describing RPC queue {queue}"
                )
                for line in format_exc().split("\n"):
                    self.log.error(line)
                response = RpcError.of(
                    request_id=req_id,
                    details=f"There was an error while generating the documentation, "
                    f"please refer to server log with request ID: {req_id}",
                ).to_json()
            self.channel.basic_publish(
                exchange="",
                routing_key=props.reply_to or "",
                properties=BasicProperties(correlation_id=props.correlation_id),
                body=response,
            )
            self.channel.basic_ack(delivery_tag=method.delivery_tag)

        self.channel.basic_consume(
            queue=docs_queue,
            on_message_callback=docs_callback,
            auto_ack=False,
        )

    @typechecked
    def _get_defaults(self, queue: str) -> Dict:
        """Map each parameter of the ``queue`` callback that has a default to that default."""
        params = signature(self._callbacks[queue]).parameters
        return {
            name: param.default
            for name, param in params.items()
            if param.default is not Parameter.empty
        }

    @typechecked
    def _get_input_types(self, queue: str) -> Dict:
        """Map each annotated parameter of the ``queue`` callback to its resolved type hint."""
        return {
            k: v
            for k, v in get_type_hints(self._callbacks[queue]).items()
            if k != "return"
        }

    @typechecked
    def _get_input_type_strings(self, queue: str) -> Dict:
        """Like ``_get_input_types``, with each hint rendered by ``_type_hint_str``."""
        return {k: _type_hint_str(v) for k, v in self._get_input_types(queue).items()}

    @typechecked
    def _get_return_type_str(self, queue: str) -> Optional[str]:
        """Render the return hint of the ``queue`` callback as a string.

        Returns ``None`` when the callback has no return annotation. Before
        this check, an unannotated return raised ``KeyError``.
        """
        hints = get_type_hints(self._callbacks[queue])
        if "return" not in hints:
            return None
        return _type_hint_str(hints["return"])

    @staticmethod
    @typechecked
    def _is_union(arg_type: Any) -> bool:
        """Return True for ``Union``/``Optional`` hints and (3.10+) PEP 604 ``X | Y`` hints."""
        import types  # local import: ``types.UnionType`` only exists on 3.10+

        origin = get_origin(arg_type)
        return origin is Union or origin is getattr(types, "UnionType", Union)

    @staticmethod
    @typechecked
    def _is_optional(arg_type: Any) -> bool:
        """Return True if ``arg_type`` is a union that includes ``None`` (e.g. ``Optional[int]``)."""
        return RpcServer._is_union(arg_type) and type(None) in get_args(arg_type)

    @staticmethod
    @typechecked
    def _valid_arg(arg_value: Any, arg_type: Any) -> bool:
        """Check a decoded JSON value against a type hint.

        Handling per hint kind:

        - ``Any`` accepts everything.
        - Unions accept a value matching any member. An optional union also
          accepts ``None`` and ``{}``.
        - ``Dict[K, V]`` requires a dict whose keys and values match ``K``
          and ``V``.
        - ``List[T]`` requires a list whose items match ``T``.
        - Other parameterised generics with a class origin, such as bare
          ``Dict``, ``Tuple[...]`` or ``Sequence[...]``, only check the
          container type.
        - Everything else falls back to ``isinstance(arg_value, arg_type)``.

        Values of the wrong container type now yield ``False`` instead of
        raising ``AttributeError``/``TypeError``.
        """
        if arg_type is Any:
            return True
        if RpcServer._is_union(arg_type):
            members = get_args(arg_type)
            # An optional argument also accepts ``{}``; presumably some
            # clients encode "no value" that way -- TODO confirm.
            if type(None) in members and (arg_value is None or arg_value == {}):
                return True
            return any(RpcServer._valid_arg(arg_value, typ) for typ in members)
        origin = get_origin(arg_type)
        if not isinstance(origin, type):
            # Plain classes (origin is None) and special forms: plain isinstance.
            return isinstance(arg_value, arg_type)
        if not isinstance(arg_value, origin):
            return False
        type_args = get_args(arg_type)
        if origin is dict and len(type_args) == 2:
            key_type, val_type = type_args
            return all(
                RpcServer._valid_arg(key, key_type) for key in arg_value
            ) and all(RpcServer._valid_arg(val, val_type) for val in arg_value.values())
        if origin is list and len(type_args) == 1:
            return all(RpcServer._valid_arg(item, type_args[0]) for item in arg_value)
        return True

    def _valid_inputs(self, queue: str, inputs: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        """Validate request ``inputs`` against the ``queue`` callback's type hints.

        Parameter defaults are merged under ``inputs``, so a defaulted
        parameter only needs to be supplied to override it. Only annotated
        parameters are checked; keys not matching any parameter are not
        rejected here.

        Returns:
            ``(True, None)`` when valid. Otherwise ``(False, message)``, where
            the message names the first missing required argument or the
            first argument with the wrong type.
        """
        inputs_with_defaults = {**self._get_defaults(queue), **inputs}
        for arg_name, arg_type in self._get_input_types(queue).items():
            if arg_name not in inputs_with_defaults:
                if not self._is_optional(arg_type):
                    return False, f"Missing required argument {arg_name}"
            elif not self._valid_arg(inputs_with_defaults[arg_name], arg_type):
                return False, f"Invalid type for {arg_name}"
        return True, None

    def _call_kwargs(self, queue: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the keyword arguments to call the ``queue`` callback with.

        ``_valid_inputs`` allows an ``Optional`` parameter to be omitted even
        when it has no default. Such parameters are passed as ``None``, so
        the call does not fail with a missing-argument ``TypeError``.
        """
        defaults = self._get_defaults(queue)
        kwargs = dict(inputs)
        for arg_name, arg_type in self._get_input_types(queue).items():
            if (
                arg_name not in kwargs
                and arg_name not in defaults
                and self._is_optional(arg_type)
            ):
                kwargs[arg_name] = None
        return kwargs

    def _process_request(
        self, queue: str, callback: Callable, req_id: Union[str, UUID], body: bytes
    ) -> str:
        """Decode, validate and execute one RPC request, and return the JSON reply.

        Exceptions raised by decoding or by ``callback`` propagate to the
        caller. The caller turns them into an ``RpcError`` reply.
        """
        text = body.decode()
        if not valid_json(text):
            return RpcError.of(req_id, "Input could not be decoded as JSON").to_json()
        input_args = _from_json(text)
        self.log.debug(f"Input JSON is valid: {input_args}")
        if not isinstance(input_args, dict):
            # Arguments are passed by keyword, so only a JSON object is usable.
            return RpcError.of(req_id, "Input must be a JSON object").to_json()
        valid_inputs, msg = self._valid_inputs(queue, input_args)
        if not valid_inputs:
            self.log.debug("Invalid input")
            return RpcError.of(req_id, msg).to_json()
        call_kwargs = self._call_kwargs(queue, input_args)
        if input_args:
            self.log.debug(f"Executing: {callback} with inputs: {input_args}")
        else:
            self.log.debug(f"Executing: {callback}")
        return RpcResponse.of(req_id, callback(**call_kwargs)).to_json()

    @typechecked
    def register_rpc(
        self, queue: str, callback: Callable, register_docs: bool = True
    ) -> None:
        """Serve ``callback`` as an RPC on ``queue``.

        Declares ``queue`` and consumes from it. Each request is handled by
        ``_process_request``, and its reply is published with the request's
        ``correlation_id`` to its ``reply_to`` routing key. The request is
        acked only after the reply has been published. If a RabbitMQ error
        occurs, it is logged and the request is not acked.

        Args:
            queue: Name of the queue to serve the RPC on.
            callback: Function to invoke. Its type hints drive input validation.
            register_docs: Also serve a documentation queue for this RPC.
        """
        self.channel.queue_declare(queue=queue, exclusive=False)
        self._callbacks[queue] = callback
        if register_docs:
            self._register_docs(queue, callback)
        # TODO: servers should have an exclusive Queue for information about themselves

        @typechecked
        def meta_callback(_, method, props: BasicProperties, body: bytes):
            req_id = props.correlation_id or uuid()
            with thread_name(req_id):
                self.log.debug("Request received")
                try:
                    try:
                        response = self._process_request(queue, callback, req_id, body)
                    except (AMQPError, ChannelError):
                        raise  # we want this kind of exception to be caught further down
                    except Exception:  # pylint: disable=W0703
                        self.log.error(
                            "An error has occurred during the execution of the RPC method"
                        )
                        for line in format_exc().split("\n"):
                            self.log.error(line)
                        response = RpcError.of(
                            request_id=req_id,
                            details=f"There was an error while processing the request, "
                            f"please refer to server log with request ID: {req_id}",
                        ).to_json()
                    self.log.debug(f"Response: {response}")
                    self.channel.basic_publish(
                        exchange="",
                        routing_key=props.reply_to or "",
                        properties=BasicProperties(correlation_id=props.correlation_id),
                        body=response,
                    )
                    self.channel.basic_ack(delivery_tag=method.delivery_tag)
                    self.log.debug("Response sent and ack-ed")
                except (AMQPError, ChannelError):
                    self.log.error(
                        f"A RabbitMQ communication error has occurred while processing "
                        f"Request ID: {req_id}"
                    )
                    for line in format_exc().split("\n"):
                        self.log.error(line)
                if self.only_once:
                    self.channel.stop_consuming()

        self.channel.basic_consume(
            queue=queue, on_message_callback=meta_callback, auto_ack=False
        )

    @retry(
        retry=retry_if_exception_type((AMQPConnectionError, OSError)),
        wait=wait_exponential(multiplier=1, min=3, max=30),
    )
    @typechecked
    def run(self) -> None:
        """Consume and serve requests until interrupted.

        ``AMQPConnectionError``/``OSError`` are retried with exponential
        back-off (waits clamped to 3-30 s). No stop condition is configured,
        so retries continue indefinitely. A ``KeyboardInterrupt`` stops
        consuming and returns.
        """
        # NOTE(review): consumers registered via register_rpc were bound to
        # ``self.channel``. Whether they survive a retry depends on how
        # Robotnik re-creates the channel -- confirm.
        try:
            self.channel.start_consuming()
        except KeyboardInterrupt:
            self.channel.stop_consuming()
            self.log.info("Shutting down server")