# Coverage report residue (coverage.py v7.6.12, 2025-08-13): arrakis/flight.py — 93.5% of 123 statements covered.

# Copyright (c) 2022, California Institute of Technology and contributors
#
# You should have received a copy of the licensing terms for this
# software included in the file "LICENSE" located in the top-level
# directory of this package. If you did not, you can view a copy at
# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE

7 

8"""Arrow Flight utilities.""" 

9 

10from __future__ import annotations 

11 

12import concurrent.futures 

13import contextlib 

14import json 

15import logging 

16import os 

17import queue 

18import threading 

19from enum import IntEnum, auto 

20from importlib import resources 

21from typing import TYPE_CHECKING, Any, TypedDict 

22from unittest.mock import sentinel 

23from urllib.parse import urlparse 

24 

25import arrakis_schema 

26import jsonschema 

27from pyarrow import flight 

28from pyarrow.flight import connect 

29 

30from . import constants 

31 

32if TYPE_CHECKING: 

33 from collections.abc import Generator 

34 from datetime import timedelta 

35 

36 

# package-wide logger; all modules in this package log under "arrakis"
logger = logging.getLogger("arrakis")


# Sentinel pushed onto the internal stream queue by worker threads to
# signal that an endpoint has finished producing data (end-of-stream).
EOS = sentinel.EOS

41 

42 

class RequestType(IntEnum):
    """Enumeration of the Flight request types understood by this client.

    Member names are used verbatim as the ``request`` field of
    JSON-encoded commands, and (lowercased) as the base name of the
    per-request schema file in the ``arrakis_schema`` package
    (e.g. ``Stream`` -> ``stream.json``).

    """

    Stream = auto()
    Describe = auto()
    Find = auto()
    Count = auto()
    Publish = auto()
    Partition = auto()

50 

51 

class Request(TypedDict):
    """A JSON-encodable Flight request payload."""

    # name of a RequestType member, e.g. "Stream" or "Find"
    request: str
    # request-specific keyword arguments
    args: dict[str, Any]

55 

56 

class RequestValidator:
    """A validator for JSON-encoded requests.

    All payloads are first checked against the generic descriptor
    schema, then against a request-type-specific schema which is
    loaded lazily and cached on first use.

    """

    def __init__(self) -> None:
        # per-request-type validators, populated lazily by validate()
        self._validators: dict[RequestType, jsonschema.Draft7Validator] = {}
        # generic descriptor schema applied to every payload
        self._generic_validator = self._build_validator("descriptor")

    @staticmethod
    def _build_validator(name: str) -> jsonschema.Draft7Validator:
        """Load the named JSON schema from arrakis_schema and wrap it."""
        resource = resources.files(arrakis_schema).joinpath(f"{name}.json")
        with resources.as_file(resource) as path:
            contents = json.loads(path.read_text())
        return jsonschema.Draft7Validator(contents)

    def validate(self, payload: Request) -> None:
        """Validate a JSON-encoded request.

        Parameters
        ----------
        payload : Request
            A dictionary with a 'request' and an 'args' key encoding
            the given Flight request.

        Raises
        ------
        ValidationError
            If the request does not match the expected schema.

        """
        self._generic_validator.validate(payload)
        request = RequestType[payload["request"]]

        # load and cache the request-specific schema on demand
        if request not in self._validators:
            self._validators[request] = self._build_validator(request.name.lower())

        self._validators[request].validate(payload)

97 

98 

def parse_url(url: str | None) -> str:
    """Resolve and validate an Arrakis server URL.

    Parameters
    ----------
    url : str or None
        Server URL. If None, it is read from the ARRAKIS_SERVER
        environment variable, falling back to the package default.

    Returns
    -------
    str
        The normalized URL, with the grpc scheme applied if none was given.

    Raises
    ------
    ValueError
        If no URL could be determined, or if an explicit scheme other
        than grpc:// was supplied.

    """
    if url is None:
        url = os.getenv("ARRAKIS_SERVER", constants.DEFAULT_ARRAKIS_SERVER)
    if url is None:
        # raise explicitly rather than `assert`: assertions are
        # stripped under `python -O`, which would let a missing server
        # configuration slip through silently
        msg = "ARRAKIS_SERVER not specified."
        raise ValueError(msg)
    parsed = urlparse(url, scheme="grpc")
    if parsed.scheme != "grpc":
        msg = f"invalid URL {url}. if scheme is specified, it must start with grpc://"
        raise ValueError(msg)
    return parsed.geturl()

108 

109 

def create_command(
    request_type: RequestType, *, validator: RequestValidator, **kwargs
) -> bytes:
    """Serialize a request into a JSON-encoded Flight command.

    Parameters
    ----------
    request_type : RequestType
        The type of request.
    validator : RequestValidator
        A validator to validate that the command matches the expected schema.
    **kwargs : dict, optional
        Extra arguments corresponding to the specific request.

    Returns
    -------
    bytes
        The JSON-encoded request.

    Raises
    ------
    ValidationError
        If the request does not match the expected schema.

    """
    payload: Request = {"request": request_type.name, "args": kwargs}
    # reject malformed requests before they ever reach the wire
    validator.validate(payload)
    return json.dumps(payload).encode("utf-8")

141 

142 

def create_descriptor(
    request_type: RequestType, *, validator: RequestValidator, **kwargs
) -> flight.FlightDescriptor:
    """Wrap a validated request in a Flight descriptor.

    Parameters
    ----------
    request_type : RequestType
        The type of request.
    validator : RequestValidator
        A validator to validate that the command matches the expected schema.
    **kwargs : dict, optional
        Extra arguments corresponding to the specific request.

    Returns
    -------
    flight.FlightDescriptor
        A Flight Descriptor containing the request.

    Raises
    ------
    ValidationError
        If the request does not match the expected schema.

    """
    command = create_command(request_type, validator=validator, **kwargs)
    return flight.FlightDescriptor.for_command(command)

170 

171 

def parse_command(
    cmd: bytes, *, validator: RequestValidator
) -> tuple[RequestType, dict]:
    """Parse a Flight command into a request.

    Parameters
    ----------
    cmd : bytes
        The JSON-encoded request.
    validator : RequestValidator
        A validator to validate that the command matches the expected schema.

    Returns
    -------
    request_type : RequestType
        The type of request.
    kwargs : dict
        Arguments corresponding to the specific request.

    Raises
    ------
    JSONDecodeError
        If the command does not decode to valid JSON.
    ValidationError
        If the request does not match the expected schema.

    """
    try:
        payload = json.loads(cmd.decode("utf-8"))
    except json.JSONDecodeError as exc:
        # re-raise with a clearer message, preserving the parse position
        msg = "Command does not decode to valid JSON"
        raise json.JSONDecodeError(msg, exc.doc, exc.pos) from exc
    validator.validate(payload)
    return RequestType[payload["request"]], payload["args"]

207 

208 

class MultiEndpointStream(contextlib.AbstractContextManager):
    """Multi-threaded Arrow Flight endpoint stream iterator context manager

    Given a list of endpoints, connect to all of them in parallel and
    stream data from them all interleaved.

    Each endpoint is read by its own worker thread, which pushes
    (chunk, endpoint) pairs onto a shared queue; iteration drains that
    queue until every worker has posted an EOS marker.

    """

    def __init__(
        self,
        endpoints: list[flight.FlightEndpoint],
        initial_client: flight.FlightClient,
    ):
        """Initialize with a list of endpoints and a reusable flight client."""
        self.endpoints = endpoints
        self.initial_client = initial_client
        # worker threads push (chunk, endpoint) pairs here; __iter__ drains it
        self.q: queue.SimpleQueue = queue.SimpleQueue()
        # set to ask all worker threads to stop streaming early
        self.quit_event: threading.Event = threading.Event()
        # one worker thread per endpoint
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=len(self.endpoints),
        )
        # per-endpoint end-of-stream flags, keyed by the endpoint's
        # serialized form (presumably because FlightEndpoint is not
        # usable as a dict key directly — TODO confirm)
        self.threads_done = {endpoint.serialize(): False for endpoint in endpoints}
        # set while a stream is running; None once closed
        self.futures: list[concurrent.futures.Future] | None = None

    def _execute_endpoint(self, endpoint: flight.FlightEndpoint):
        """Stream chunks from a single endpoint into the shared queue.

        Runs in a worker thread; always posts an EOS marker for its
        endpoint on exit, even on error.

        """
        logger.debug("endpoint: %s", endpoint)
        # FIXME: endpoints can contain multiple locations from which
        # the ticket can be served, considered as data replicas. we
        # should cycle through backup locations if there are
        # connection issues with the primary one.
        location = endpoint.locations[0]
        # if an endpoint is specified with the special location:
        # "arrow-flight-reuse-connection://?"
        # then the initial client connection will be reused.
        # see: https://arrow.apache.org/docs/format/Flight.html#connection-reuse
        scheme = urlparse(location.uri.decode()).scheme
        if scheme == "arrow-flight-reuse-connection":
            context: contextlib.AbstractContextManager[flight.FlightClient]
            # nullcontext so the shared client is NOT closed on exit
            context = contextlib.nullcontext(self.initial_client)
        else:
            context = connect(location)
        with context as client:
            try:
                for chunk in client.do_get(endpoint.ticket):
                    # stop promptly if the consumer has requested shutdown
                    if self.quit_event.is_set():
                        break
                    self.q.put((chunk, endpoint))
            finally:
                # always signal end-of-stream so the consumer in
                # __iter__ can mark this endpoint as done
                self.q.put((EOS, endpoint))

    def __iter__(
        self,
        timeout: timedelta = constants.DEFAULT_QUEUE_TIMEOUT,
    ) -> Generator[
        flight.FlightStreamReader
        | tuple[flight.FlightStreamReader, flight.FlightEndpoint],
        None,
        None,
    ]:
        """Execute the streams and yield the results

        Yielded results are a tuple of the data chunk, and the
        endpoint it came from.

        The timeout is expected to be a timedelta object.

        """
        # launch one worker per endpoint
        self.futures = [
            self.executor.submit(self._execute_endpoint, endpoint)
            for endpoint in self.endpoints
        ]

        # drain the queue until every endpoint has posted its EOS marker
        while not all(self.threads_done.values()):
            try:
                data, endpoint = self.q.get(block=True, timeout=timeout.total_seconds())
            except queue.Empty:
                # timed out waiting for data; fall through to check
                # worker health below before waiting again
                pass
            else:
                if data is EOS:
                    self.threads_done[endpoint.serialize()] = True
                else:
                    yield data, endpoint
            # if any worker died with an exception, ask the rest to
            # quit; the exception itself is re-raised later, in close()
            for future in self.futures:
                if future.done() and future.exception():
                    self.quit_event.set()

    # alias so callers can spell iteration as an explicit .stream() call
    stream = __iter__

    def unpack(self):
        """Unpack stream data into individual elements"""
        for chunk, _ in self:
            yield from chunk.data.to_pylist()

    def close(self):
        """close all streams"""
        # ask all worker threads to stop
        self.quit_event.set()
        if self.futures is not None:
            for f in self.futures:
                # re-raise exceptions to the client, returning
                # user-friendly Flight-specific errors when relevant
                try:
                    f.result()
                except flight.FlightError as e:
                    # NOTE: this strips the original message of everything
                    # besides the original error message raised by the server
                    msg = e.args[0].partition(" Detail:")[0]
                    raise type(e)(msg, e.extra_info) from None

            self.executor.shutdown(cancel_futures=True)
            self.futures = None

    def __exit__(self, exc_type, exc_value, traceback):
        # context-manager exit simply closes all streams
        self.close()