Coverage for arrakis/flight.py: 93.5%
123 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-13 15:09 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-13 15:09 -0700
1# Copyright (c) 2022, California Institute of Technology and contributors
2#
3# You should have received a copy of the licensing terms for this
4# software included in the file "LICENSE" located in the top-level
5# directory of this package. If you did not, you can view a copy at
6# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE
8"""Arrow Flight utilities."""
10from __future__ import annotations
12import concurrent.futures
13import contextlib
14import json
15import logging
16import os
17import queue
18import threading
19from enum import IntEnum, auto
20from importlib import resources
21from typing import TYPE_CHECKING, Any, TypedDict
22from unittest.mock import sentinel
23from urllib.parse import urlparse
25import arrakis_schema
26import jsonschema
27from pyarrow import flight
28from pyarrow.flight import connect
30from . import constants
32if TYPE_CHECKING:
33 from collections.abc import Generator
34 from datetime import timedelta
37logger = logging.getLogger("arrakis")
40EOS = sentinel.EOS
class RequestType(IntEnum):
    """The kind of Flight request being made.

    Values are explicit but match what ``auto()`` would assign,
    starting from 1 in declaration order.
    """

    Stream = 1
    Describe = 2
    Find = 3
    Count = 4
    Publish = 5
    Partition = 6
class Request(TypedDict):
    """A JSON-serializable Flight request payload."""

    # name of a RequestType member, e.g. "Stream" or "Describe"
    request: str
    # request-specific keyword arguments, validated against the
    # per-request JSON schema by RequestValidator
    args: dict[str, Any]
class RequestValidator:
    """A validator for JSON-encoded requests."""

    def __init__(self) -> None:
        # per-request-type validators, created lazily on first use
        self._validators: dict[RequestType, jsonschema.Draft7Validator] = {}

        # the generic descriptor schema applies to every request payload
        self._generic_validator = jsonschema.Draft7Validator(
            self._load_schema("descriptor")
        )

    @staticmethod
    def _load_schema(name: str) -> dict:
        """Load a JSON schema bundled with arrakis_schema by base name."""
        resource = resources.files(arrakis_schema).joinpath(f"{name}.json")
        with resources.as_file(resource) as path:
            return json.loads(path.read_text())

    def validate(self, payload: Request) -> None:
        """Validate a JSON-encoded request.

        Parameters
        ----------
        payload : Request
            A dictionary with a 'request' and an 'args' key encoding
            the given Flight request.

        Raises
        ------
        ValidationError
            If the request does not match the expected schema.

        """
        self._generic_validator.validate(payload)
        request = RequestType[payload["request"]]

        # load the request-specific schema on demand
        if request not in self._validators:
            schema = self._load_schema(request.name.lower())
            self._validators[request] = jsonschema.Draft7Validator(schema)

        self._validators[request].validate(payload)
99def parse_url(url: str | None) -> str:
100 if url is None:
101 url = os.getenv("ARRAKIS_SERVER", constants.DEFAULT_ARRAKIS_SERVER)
102 assert url is not None, "ARRAKIS_SERVER not specified."
103 parsed = urlparse(url, scheme="grpc")
104 if parsed.scheme != "grpc":
105 msg = f"invalid URL {url}. if scheme is specified, it must start with grpc://"
106 raise ValueError(msg)
107 return parsed.geturl()
def create_command(
    request_type: RequestType, *, validator: RequestValidator, **kwargs
) -> bytes:
    """Create a Flight command containing a JSON-encoded request.

    Parameters
    ----------
    request_type : RequestType
        The type of request.
    validator : RequestValidator
        A validator to validate that the command matches the expected schema.
    **kwargs : dict, optional
        Extra arguments corresponding to the specific request.

    Returns
    -------
    bytes
        The JSON-encoded request.

    Raises
    ------
    ValidationError
        If the request does not match the expected schema.

    """
    payload: Request = {"request": request_type.name, "args": kwargs}
    validator.validate(payload)
    return json.dumps(payload).encode("utf-8")
def create_descriptor(
    request_type: RequestType, *, validator: RequestValidator, **kwargs
) -> flight.FlightDescriptor:
    """Create a Flight descriptor given a request.

    Parameters
    ----------
    request_type : RequestType
        The type of request.
    validator : RequestValidator
        A validator to validate that the command matches the expected schema.
    **kwargs : dict, optional
        Extra arguments corresponding to the specific request.

    Returns
    -------
    flight.FlightDescriptor
        A Flight Descriptor containing the request.

    Raises
    ------
    ValidationError
        If the request does not match the expected schema.

    """
    return flight.FlightDescriptor.for_command(
        create_command(request_type, validator=validator, **kwargs)
    )
def parse_command(
    cmd: bytes, *, validator: RequestValidator
) -> tuple[RequestType, dict]:
    """Parse a Flight command into a request.

    Parameters
    ----------
    cmd : bytes
        The JSON-encoded request.
    validator : RequestValidator
        A validator to validate that the command matches the expected schema.

    Returns
    -------
    request_type : RequestType
        The type of request.
    kwargs : dict
        Arguments corresponding to the specific request.

    Raises
    ------
    JSONDecodeError
        If the command does not decode to valid JSON.
    ValidationError
        If the request does not match the expected schema.

    """
    try:
        payload = json.loads(cmd.decode("utf-8"))
    except json.JSONDecodeError as exc:
        raise json.JSONDecodeError(
            "Command does not decode to valid JSON", exc.doc, exc.pos
        ) from exc
    # only validated (and returned) when decoding succeeded
    validator.validate(payload)
    return RequestType[payload["request"]], payload["args"]
class MultiEndpointStream(contextlib.AbstractContextManager):
    """Multi-threaded Arrow Flight endpoint stream iterator context manager

    Given a list of endpoints, connect to all of them in parallel and
    stream data from them all interleaved.

    """

    def __init__(
        self,
        endpoints: list[flight.FlightEndpoint],
        initial_client: flight.FlightClient,
    ):
        """Initialize with a list of endpoints and a reusable flight client.

        Parameters
        ----------
        endpoints : list[flight.FlightEndpoint]
            The endpoints to stream from, one worker thread per endpoint.
        initial_client : flight.FlightClient
            An existing client connection, reused for endpoints that
            advertise the special connection-reuse location.

        """
        self.endpoints = endpoints
        self.initial_client = initial_client
        # worker threads push (chunk, endpoint) pairs here; EOS marks
        # a finished endpoint
        self.q: queue.SimpleQueue = queue.SimpleQueue()
        self.quit_event: threading.Event = threading.Event()
        # max_workers must be positive even if the endpoint list is empty
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=max(len(self.endpoints), 1),
        )
        # per-endpoint completion flags, keyed by serialized endpoint
        # (which is hashable)
        self.threads_done = {endpoint.serialize(): False for endpoint in endpoints}
        self.futures: list[concurrent.futures.Future] | None = None

    def _execute_endpoint(self, endpoint: flight.FlightEndpoint):
        """Stream chunks from a single endpoint into the shared queue.

        The EOS marker is always enqueued on exit -- even if connecting
        to the endpoint fails -- so that __iter__ cannot wait forever
        for this endpoint to finish.
        """
        logger.debug("endpoint: %s", endpoint)
        try:
            # FIXME: endpoints can contain multiple locations from which
            # the ticket can be served, considered as data replicas. we
            # should cycle through backup locations if there are
            # connection issues with the primary one.
            location = endpoint.locations[0]
            # if an endpoint is specified with the special location:
            # "arrow-flight-reuse-connection://?"
            # then the initial client connection will be reused.
            # see: https://arrow.apache.org/docs/format/Flight.html#connection-reuse
            scheme = urlparse(location.uri.decode()).scheme
            if scheme == "arrow-flight-reuse-connection":
                context: contextlib.AbstractContextManager[flight.FlightClient]
                context = contextlib.nullcontext(self.initial_client)
            else:
                context = connect(location)
            with context as client:
                for chunk in client.do_get(endpoint.ticket):
                    if self.quit_event.is_set():
                        break
                    self.q.put((chunk, endpoint))
        finally:
            self.q.put((EOS, endpoint))

    def __iter__(
        self,
        timeout: timedelta = constants.DEFAULT_QUEUE_TIMEOUT,
    ) -> Generator[
        flight.FlightStreamReader
        | tuple[flight.FlightStreamReader, flight.FlightEndpoint],
        None,
        None,
    ]:
        """Execute the streams and yield the results

        Yielded results are a tuple of the data chunk, and the
        endpoint it came from.

        The timeout is expected to be a timedelta object.

        """
        self.futures = [
            self.executor.submit(self._execute_endpoint, endpoint)
            for endpoint in self.endpoints
        ]

        while not all(self.threads_done.values()):
            try:
                data, endpoint = self.q.get(block=True, timeout=timeout.total_seconds())
            except queue.Empty:
                pass
            else:
                if data is EOS:
                    self.threads_done[endpoint.serialize()] = True
                else:
                    yield data, endpoint
            # if any worker died with an exception, signal the rest to
            # stop; the exception itself is re-raised later in close()
            for future in self.futures:
                if future.done() and future.exception():
                    self.quit_event.set()

    stream = __iter__

    def unpack(self):
        """Unpack stream data into individual elements"""
        for chunk, _ in self:
            yield from chunk.data.to_pylist()

    def close(self):
        """close all streams"""
        self.quit_event.set()
        if self.futures is not None:
            for f in self.futures:
                # re-raise exceptions to the client, returning
                # user-friendly Flight-specific errors when relevant
                try:
                    f.result()
                except flight.FlightError as e:
                    # NOTE: this strips the original message of everything
                    # besides the original error message raised by the server
                    msg = e.args[0].partition(" Detail:")[0]
                    raise type(e)(msg, e.extra_info) from None
        self.executor.shutdown(cancel_futures=True)
        self.futures = None

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()