Coverage for arrakis/client.py: 95.7%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-08-13 15:09 -0700

1# Copyright (c) 2022, California Institute of Technology and contributors 

2# 

3# You should have received a copy of the licensing terms for this 

4# software included in the file "LICENSE" located in the top-level 

5# directory of this package. If you did not, you can view a copy at 

6# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE 

7 

8"""Client-based API access.""" 

9 

10import logging 

11from collections.abc import Generator 

12 

13import numpy 

14import pyarrow 

15from pyarrow import flight 

16from pyarrow.flight import connect 

17 

18from . import constants 

19from .block import SeriesBlock, combine_blocks, concatenate_blocks, time_as_ns 

20from .channel import Channel 

21from .flight import ( 

22 MultiEndpointStream, 

23 RequestType, 

24 RequestValidator, 

25 create_descriptor, 

26 parse_url, 

27) 

28from .mux import Muxer 

29 

# Module-level logger shared by all helpers in this file.
logger = logging.getLogger("arrakis")

# Anything accepted as a data-type specifier: a dtype name string, a
# Python/numpy scalar type, a numpy.dtype instance, or a list of any of those.
# Normalized to plain dtype-name strings by _parse_data_types below.
DataTypeLike = str | list[str] | type | list[type] | numpy.dtype | list[numpy.dtype]

33 

34 

def get_flight_info(
    client: flight.FlightClient, descriptor: flight.FlightDescriptor
) -> flight.FlightInfo:
    """Request flight info for a given request descriptor.

    Parameters
    ----------
    client : flight.FlightClient
        Connected Flight client to issue the request with.
    descriptor : flight.FlightDescriptor
        Descriptor describing the request.

    Returns
    -------
    flight.FlightInfo
        Flight info listing the endpoints that serve this request.

    """
    flight_info = client.get_flight_info(descriptor)
    # %d rather than %g: the endpoint count is an integer, not a float
    logger.debug(
        "flight info received, %d endpoints identified", len(flight_info.endpoints)
    )
    return flight_info

43 

44 

class Client:
    """Client to fetch or publish timeseries.

    Parameters
    ----------
    url : str, optional
        The URL to connect to.
        If the URL is not set, connect to a default server
        or one set by ARRAKIS_SERVER.

    """

    def __init__(self, url: str | None = None):
        self.initial_url = parse_url(url)
        logger.debug("initial url: %s", self.initial_url)
        self._validator = RequestValidator()

    def _query_descriptor(
        self,
        request: RequestType,
        pattern: str,
        data_type: DataTypeLike | None,
        min_rate: int | None,
        max_rate: int | None,
        publisher: str | list[str] | None,
    ) -> flight.FlightDescriptor:
        """Build a channel-query descriptor, normalizing optional arguments.

        Shared by `find` and `count` so their argument-defaulting rules
        cannot drift apart.
        """
        if min_rate is None:
            min_rate = constants.MIN_SAMPLE_RATE
        if max_rate is None:
            max_rate = constants.MAX_SAMPLE_RATE
        if publisher is None:
            publisher = []
        elif isinstance(publisher, str):
            publisher = [publisher]
        return create_descriptor(
            request,
            pattern=pattern,
            data_type=_parse_data_types(data_type),
            min_rate=min_rate,
            max_rate=max_rate,
            publisher=publisher,
            validator=self._validator,
        )

    def find(
        self,
        pattern: str = constants.DEFAULT_MATCH,
        data_type: DataTypeLike | None = None,
        min_rate: int | None = constants.MIN_SAMPLE_RATE,
        max_rate: int | None = constants.MAX_SAMPLE_RATE,
        publisher: str | list[str] | None = None,
    ) -> Generator[Channel, None, None]:
        """Find channels matching a set of conditions

        Parameters
        ----------
        pattern : str, optional
            Channel pattern to match channels with, using regular expressions.
        data_type : numpy.dtype-like | list[numpy.dtype-like], optional
            If set, find all channels with these data types.
        min_rate : int, optional
            Minimum sampling rate for channels.
        max_rate : int, optional
            Maximum sampling rate for channels.
        publisher : str | list[str], optional
            If set, find all channels associated with these publishers.

        Yields
        -------
        Channel
            Channel objects for all channels matching query.

        """
        descriptor = self._query_descriptor(
            RequestType.Find, pattern, data_type, min_rate, max_rate, publisher
        )
        with connect(self.initial_url) as client:
            yield from self._stream_channel_metadata(client, descriptor)

    def count(
        self,
        pattern: str = constants.DEFAULT_MATCH,
        data_type: DataTypeLike | None = None,
        min_rate: int | None = constants.MIN_SAMPLE_RATE,
        max_rate: int | None = constants.MAX_SAMPLE_RATE,
        publisher: str | list[str] | None = None,
    ) -> int:
        """Count channels matching a set of conditions

        Parameters
        ----------
        pattern : str, optional
            Channel pattern to match channels with, using regular expressions.
        data_type : numpy.dtype-like | list[numpy.dtype-like], optional
            If set, find all channels with these data types.
        min_rate : int, optional
            The minimum sampling rate for channels.
        max_rate : int, optional
            The maximum sampling rate for channels.
        publisher : str | list[str], optional
            If set, find all channels associated with these publishers.

        Returns
        -------
        int
            The number of channels matching query.

        """
        descriptor = self._query_descriptor(
            RequestType.Count, pattern, data_type, min_rate, max_rate, publisher
        )
        # each endpoint reports a partial count; sum them all
        count = 0
        with connect(self.initial_url) as client:
            flight_info = get_flight_info(client, descriptor)
            with MultiEndpointStream(flight_info.endpoints, client) as stream:
                for data in stream.unpack():
                    count += data["count"]
        return count

    def describe(self, channels: list[str]) -> dict[str, Channel]:
        """Get channel metadata for channels requested

        Parameters
        ----------
        channels : list[str]
            List of channels to request.

        Returns
        -------
        dict[str, Channel]
            Mapping of channel names to channel metadata.

        """
        descriptor = create_descriptor(
            RequestType.Describe, channels=channels, validator=self._validator
        )
        with connect(self.initial_url) as client:
            return {
                channel.name: channel
                for channel in self._stream_channel_metadata(client, descriptor)
            }

    def stream(
        self,
        channels: list[str],
        start: float | None = None,
        end: float | None = None,
    ) -> Generator[SeriesBlock, None, None]:
        """Stream live or offline timeseries data

        Parameters
        ----------
        channels : list[str]
            List of channels to request.
        start : float, optional
            GPS start time, in seconds.
        end : float, optional
            GPS end time, in seconds.

        Yields
        ------
        SeriesBlock
            Dictionary-like object containing all requested channel data.

        Setting neither start nor end begins a live stream starting
        from now.

        """
        start_ns = time_as_ns(start) if start is not None else None
        end_ns = time_as_ns(end) if end is not None else None
        # channel metadata and per-endpoint schemas, updated lazily as
        # batches arrive with new or changed schema metadata
        metadata: dict[str, Channel] = {}
        schemas: dict[str, pyarrow.Schema] = {}

        with connect(self.initial_url) as client:
            descriptor = create_descriptor(
                RequestType.Stream,
                channels=channels,
                start=start_ns,
                end=end_ns,
                validator=self._validator,
            )
            flight_info = get_flight_info(client, descriptor)
            # use the serialized endpoints as the mux keys
            keys = [e.serialize() for e in flight_info.endpoints]
            mux: Muxer = Muxer(keys=keys)
            with MultiEndpointStream(flight_info.endpoints, client) as stream:
                for chunk, endpoint in stream:
                    # first entry of the "time" column timestamps the batch
                    time = chunk.data.column("time").to_numpy()[0]
                    mux.push(time, endpoint.serialize(), chunk.data)
                    # FIXME: how do we handle stream drop-outs that result
                    # in timeouts in the muxer that result in null data in
                    # the mux pull?
                    for mux_data in mux.pull():
                        blocks = []
                        # update channel metadata if needed
                        for key, batch in mux_data.items():
                            if (
                                key not in schemas
                                or schemas[key].metadata != batch.schema.metadata
                            ):
                                # skip the leading "time" field; the rest
                                # describe channels
                                channel_fields: list[pyarrow.field] = list(
                                    batch.schema
                                )[1:]
                                for field in channel_fields:
                                    metadata[field.name] = Channel.from_field(field)
                                schemas[key] = batch.schema

                            blocks.append(
                                SeriesBlock.from_column_batch(batch, metadata)
                            )

                        # generate synchronized blocks
                        yield combine_blocks(*blocks)

    def fetch(
        self,
        channels: list[str],
        start: float,
        end: float,
    ) -> SeriesBlock:
        """Fetch timeseries data

        Parameters
        ----------
        channels : list[str]
            List of channels to request.
        start : float
            GPS start time, in seconds.
        end : float
            GPS end time, in seconds.

        Returns
        -------
        SeriesBlock
            Dictionary-like object containing all requested channel data.

        """
        return concatenate_blocks(*self.stream(channels, start, end))

    def _stream_channel_metadata(
        self,
        client: flight.FlightClient,
        descriptor: flight.FlightDescriptor,
    ) -> Generator[Channel, None, None]:
        """stream channel metadata."""
        flight_info = get_flight_info(client, descriptor)
        with MultiEndpointStream(flight_info.endpoints, client) as stream:
            for channel_meta in stream.unpack():
                yield Channel(
                    channel_meta["channel"],
                    data_type=numpy.dtype(channel_meta["data_type"]),
                    sample_rate=channel_meta["sample_rate"],
                    publisher=channel_meta["publisher"],
                    partition_id=channel_meta["partition_id"],
                )

305 

306 

def _parse_data_types(
    data_types: DataTypeLike | None,
) -> list[str]:
    """Parse numpy-like data types to be JSON-serializable."""
    if data_types is None:
        return []
    # wrap a single specifier so both cases share one conversion path
    if isinstance(data_types, (str, type, numpy.dtype)):
        data_types = [data_types]
    return [numpy.dtype(entry).name for entry in data_types]