Coverage for arrakis/client.py: 95.7%
94 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-13 15:09 -0700
1# Copyright (c) 2022, California Institute of Technology and contributors
2#
3# You should have received a copy of the licensing terms for this
4# software included in the file "LICENSE" located in the top-level
5# directory of this package. If you did not, you can view a copy at
6# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE
8"""Client-based API access."""
10import logging
11from collections.abc import Generator
13import numpy
14import pyarrow
15from pyarrow import flight
16from pyarrow.flight import connect
18from . import constants
19from .block import SeriesBlock, combine_blocks, concatenate_blocks, time_as_ns
20from .channel import Channel
21from .flight import (
22 MultiEndpointStream,
23 RequestType,
24 RequestValidator,
25 create_descriptor,
26 parse_url,
27)
28from .mux import Muxer
30logger = logging.getLogger("arrakis")
32DataTypeLike = str | list[str] | type | list[type] | numpy.dtype | list[numpy.dtype]
def get_flight_info(
    client: flight.FlightClient, descriptor: flight.FlightDescriptor
) -> flight.FlightInfo:
    """Request flight info for a descriptor, logging the endpoint count.

    Parameters
    ----------
    client : flight.FlightClient
        Connected Flight client to issue the request with.
    descriptor : flight.FlightDescriptor
        Descriptor identifying the request.

    Returns
    -------
    flight.FlightInfo
        The flight info returned by the server.

    """
    flight_info = client.get_flight_info(descriptor)
    # %d: the endpoint count is an integer; %g is a float format spec
    logger.debug(
        "flight info received, %d endpoints identified", len(flight_info.endpoints)
    )
    return flight_info
class Client:
    """Client to fetch or publish timeseries.

    Parameters
    ----------
    url : str, optional
        The URL to connect to.
        If the URL is not set, connect to a default server
        or one set by ARRAKIS_SERVER.

    """

    def __init__(self, url: str | None = None):
        self.initial_url = parse_url(url)
        logger.debug("initial url: %s", self.initial_url)
        self._validator = RequestValidator()

    @staticmethod
    def _normalize_query_args(
        data_type: DataTypeLike | None,
        min_rate: int | None,
        max_rate: int | None,
        publisher: str | list[str] | None,
    ) -> tuple[list[str], int, int, list[str]]:
        """Normalize the query arguments shared by find() and count().

        Parses data types into JSON-serializable names, substitutes the
        global rate bounds for unset rates, and coerces the publisher
        argument into a list.
        """
        parsed_data_type = _parse_data_types(data_type)
        if min_rate is None:
            min_rate = constants.MIN_SAMPLE_RATE
        if max_rate is None:
            max_rate = constants.MAX_SAMPLE_RATE
        if publisher is None:
            publisher = []
        elif isinstance(publisher, str):
            publisher = [publisher]
        return parsed_data_type, min_rate, max_rate, publisher

    def find(
        self,
        pattern: str = constants.DEFAULT_MATCH,
        data_type: DataTypeLike | None = None,
        min_rate: int | None = constants.MIN_SAMPLE_RATE,
        max_rate: int | None = constants.MAX_SAMPLE_RATE,
        publisher: str | list[str] | None = None,
    ) -> Generator[Channel, None, None]:
        """Find channels matching a set of conditions

        Parameters
        ----------
        pattern : str, optional
            Channel pattern to match channels with, using regular expressions.
        data_type : numpy.dtype-like | list[numpy.dtype-like], optional
            If set, find all channels with these data types.
        min_rate : int, optional
            Minimum sampling rate for channels.
        max_rate : int, optional
            Maximum sampling rate for channels.
        publisher : str | list[str], optional
            If set, find all channels associated with these publishers.

        Yields
        -------
        Channel
            Channel objects for all channels matching query.

        """
        data_type, min_rate, max_rate, publisher = self._normalize_query_args(
            data_type, min_rate, max_rate, publisher
        )
        descriptor = create_descriptor(
            RequestType.Find,
            pattern=pattern,
            data_type=data_type,
            min_rate=min_rate,
            max_rate=max_rate,
            publisher=publisher,
            validator=self._validator,
        )
        with connect(self.initial_url) as client:
            yield from self._stream_channel_metadata(client, descriptor)

    def count(
        self,
        pattern: str = constants.DEFAULT_MATCH,
        data_type: DataTypeLike | None = None,
        min_rate: int | None = constants.MIN_SAMPLE_RATE,
        max_rate: int | None = constants.MAX_SAMPLE_RATE,
        publisher: str | list[str] | None = None,
    ) -> int:
        """Count channels matching a set of conditions

        Parameters
        ----------
        pattern : str, optional
            Channel pattern to match channels with, using regular expressions.
        data_type : numpy.dtype-like | list[numpy.dtype-like], optional
            If set, find all channels with these data types.
        min_rate : int, optional
            The minimum sampling rate for channels.
        max_rate : int, optional
            The maximum sampling rate for channels.
        publisher : str | list[str], optional
            If set, find all channels associated with these publishers.

        Returns
        -------
        int
            The number of channels matching query.

        """
        data_type, min_rate, max_rate, publisher = self._normalize_query_args(
            data_type, min_rate, max_rate, publisher
        )
        descriptor = create_descriptor(
            RequestType.Count,
            pattern=pattern,
            data_type=data_type,
            min_rate=min_rate,
            max_rate=max_rate,
            publisher=publisher,
            validator=self._validator,
        )
        count = 0
        with connect(self.initial_url) as client:
            flight_info = get_flight_info(client, descriptor)
            with MultiEndpointStream(flight_info.endpoints, client) as stream:
                # each endpoint reports a partial count; sum them all
                for data in stream.unpack():
                    count += data["count"]
        return count

    def describe(self, channels: list[str]) -> dict[str, Channel]:
        """Get channel metadata for channels requested

        Parameters
        ----------
        channels : list[str]
            List of channels to request.

        Returns
        -------
        dict[str, Channel]
            Mapping of channel names to channel metadata.

        """
        descriptor = create_descriptor(
            RequestType.Describe, channels=channels, validator=self._validator
        )
        with connect(self.initial_url) as client:
            return {
                channel.name: channel
                for channel in self._stream_channel_metadata(client, descriptor)
            }

    def stream(
        self,
        channels: list[str],
        start: float | None = None,
        end: float | None = None,
    ) -> Generator[SeriesBlock, None, None]:
        """Stream live or offline timeseries data

        Parameters
        ----------
        channels : list[str]
            List of channels to request.
        start : float, optional
            GPS start time, in seconds.
        end : float, optional
            GPS end time, in seconds.

        Yields
        ------
        SeriesBlock
            Dictionary-like object containing all requested channel data.

        Setting neither start nor end begins a live stream starting
        from now.

        """
        start_ns = time_as_ns(start) if start is not None else None
        end_ns = time_as_ns(end) if end is not None else None
        # per-endpoint schema cache and accumulated channel metadata,
        # refreshed whenever an endpoint's schema metadata changes
        metadata: dict[str, Channel] = {}
        schemas: dict[str, pyarrow.Schema] = {}

        with connect(self.initial_url) as client:
            descriptor = create_descriptor(
                RequestType.Stream,
                channels=channels,
                start=start_ns,
                end=end_ns,
                validator=self._validator,
            )
            flight_info = get_flight_info(client, descriptor)
            # use the serialized endpoints as the mux keys
            keys = [e.serialize() for e in flight_info.endpoints]
            mux: Muxer = Muxer(keys=keys)
            with MultiEndpointStream(flight_info.endpoints, client) as stream:
                for chunk, endpoint in stream:
                    time = chunk.data.column("time").to_numpy()[0]
                    mux.push(time, endpoint.serialize(), chunk.data)
                    # FIXME: how do we handle stream drop-outs that result
                    # in timeouts in the muxer that result in null data in
                    # the mux pull?
                    for mux_data in mux.pull():
                        blocks = []
                        # update channel metadata if needed
                        for key, batch in mux_data.items():
                            if (
                                key not in schemas
                                or schemas[key].metadata != batch.schema.metadata
                            ):
                                # skip the leading "time" column; the rest
                                # are channel fields
                                channel_fields: list[pyarrow.field] = list(
                                    batch.schema
                                )[1:]
                                for field in channel_fields:
                                    metadata[field.name] = Channel.from_field(field)
                                schemas[key] = batch.schema

                            blocks.append(
                                SeriesBlock.from_column_batch(batch, metadata)
                            )

                        # generate synchronized blocks
                        yield combine_blocks(*blocks)

    def fetch(
        self,
        channels: list[str],
        start: float,
        end: float,
    ) -> SeriesBlock:
        """Fetch timeseries data

        Parameters
        ----------
        channels : list[str]
            List of channels to request.
        start : float
            GPS start time, in seconds.
        end : float
            GPS end time, in seconds.

        Returns
        -------
        SeriesBlock
            Dictionary-like object containing all requested channel data.

        """
        return concatenate_blocks(*self.stream(channels, start, end))

    def _stream_channel_metadata(
        self,
        client: flight.FlightClient,
        descriptor: flight.FlightDescriptor,
    ) -> Generator[Channel, None, None]:
        """Stream channel metadata, yielding one Channel per record."""
        flight_info = get_flight_info(client, descriptor)
        with MultiEndpointStream(flight_info.endpoints, client) as stream:
            for channel_meta in stream.unpack():
                yield Channel(
                    channel_meta["channel"],
                    data_type=numpy.dtype(channel_meta["data_type"]),
                    sample_rate=channel_meta["sample_rate"],
                    publisher=channel_meta["publisher"],
                    partition_id=channel_meta["partition_id"],
                )
def _parse_data_types(
    data_types: DataTypeLike | None,
) -> list[str]:
    """Parse numpy-like data types to be JSON-serializable."""
    if data_types is None:
        return []
    # treat a bare string/type/dtype as a single-element query
    if isinstance(data_types, (str, type, numpy.dtype)):
        data_types = [data_types]
    return [numpy.dtype(dtype).name for dtype in data_types]