Coverage for arrakis/publish.py: 88.5% (87 statements)

# Copyright (c) 2022, California Institute of Technology and contributors
#
# You should have received a copy of the licensing terms for this
# software included in the file "LICENSE" located in the top-level
# directory of this package. If you did not, you can view a copy at
# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE

"""Publisher API."""

from __future__ import annotations

import contextlib
import logging
from typing import TYPE_CHECKING, Literal

import pyarrow
from pyarrow.flight import connect

from . import constants
from .client import Client
from .flight import (
    MultiEndpointStream,
    RequestType,
    RequestValidator,
    create_descriptor,
    parse_url,
)

try:
    from confluent_kafka import Producer
except ImportError:
    HAS_KAFKA = False
else:
    HAS_KAFKA = True

if TYPE_CHECKING:
    from collections.abc import Iterable
    from datetime import timedelta

    from .block import SeriesBlock
    from .channel import Channel


logger = logging.getLogger("arrakis")


def serialize_batch(batch: pyarrow.RecordBatch) -> pyarrow.Buffer:
    """Serialize a record batch to a bytes buffer.

    Parameters
    ----------
    batch : pyarrow.RecordBatch
        The batch to serialize.

    Returns
    -------
    pyarrow.Buffer
        The serialized IPC stream buffer.

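    Examples
    --------
    A minimal round-trip sketch; the column name here is illustrative,
    not the Arrakis schema:

    >>> import pyarrow
    >>> batch = pyarrow.RecordBatch.from_pydict({"value": [1, 2, 3]})
    >>> buf = serialize_batch(batch)
    >>> with pyarrow.ipc.open_stream(buf) as reader:
    ...     reader.read_all().num_rows
    3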

    """
    sink = pyarrow.BufferOutputStream()
    with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)
    return sink.getvalue()


def channel_to_dtype_name(channel: Channel) -> str:
    """Given a channel, return the data type's name."""
    assert channel.data_type is not None
    if isinstance(channel.data_type, str):
        return channel.data_type
    return channel.data_type.name


class Publisher:
    """Publisher for writing timeseries data to the Arrakis service.

    Parameters
    ----------
    publisher_id : str
        Publisher ID string.
    url : str, optional
        Initial Flight URL to connect to.

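    Examples
    --------
    A usage sketch; the publisher ID, URL, and ``block`` are illustrative
    and assume a reachable Arrakis server with channels registered for
    this publisher:

    >>> pub = Publisher("my-pub", "grpc://example.org:31206")  # doctest: +SKIP
    >>> with pub.register() as pub:  # doctest: +SKIP
    ...     pub.publish(block)  # block: a SeriesBlock covering pub's channels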

    """

    def __init__(self, publisher_id: str, url: str | None = None):
        if not HAS_KAFKA:
            msg = (
                "Publishing requires confluent-kafka to be installed. "
                "This is provided by the 'publish' extra or it can be "
                "installed manually through pip or conda."
            )
            raise ImportError(msg)

        self.publisher_id = publisher_id
        self.initial_url = parse_url(url)

        self.channels: dict[str, Channel] = {}

        self._producer: Producer
        self._partitions: dict[str, str]
        self._registered = False
        self._validator = RequestValidator()

    def register(self):
        """Register this publisher, fetching its channels and partition map."""
        assert not self._registered, "has already registered"

        self.channels = {
            channel.name: channel
            for channel in Client(self.initial_url).find(publisher=self.publisher_id)
        }
        if not self.channels:
            msg = f"unknown publisher ID '{self.publisher_id}'."
            raise ValueError(msg)

        # extract the channel partition map
        self._partitions = {}
        for channel in self.channels.values():
            if not channel.partition_id:
                msg = f"could not determine partition_id for channel {channel}."
                raise ValueError(msg)
            self._partitions[channel.name] = channel.partition_id

        self._registered = True

        return self

    def enter(self):
        """Set up the publication interface (Kafka producer)."""
        if not self._registered:
            msg = "must register publisher interface before publishing."
            raise RuntimeError(msg)

        # get connection properties
        descriptor = create_descriptor(
            RequestType.Publish,
            publisher_id=self.publisher_id,
            validator=self._validator,
        )
        properties: dict[str, str] = {}
        with connect(self.initial_url) as client:
            flight_info = client.get_flight_info(descriptor)
            with MultiEndpointStream(flight_info.endpoints, client) as stream:
                for data in stream.unpack():
                    kv_pairs = data["properties"]
                    properties.update(dict(kv_pairs))

        # set up producer
        self._producer = Producer(
            {
                "message.max.bytes": 10_000_000,  # 10 MB
                "enable.idempotence": True,
                **properties,
            }
        )

    def __enter__(self) -> Publisher:
        self.enter()
        return self

    def publish(
        self,
        block: SeriesBlock,
        timeout: timedelta = constants.DEFAULT_TIMEOUT,
    ) -> None:
        """Publish timeseries data.

        Parameters
        ----------
        block : SeriesBlock
            A data block with all channels to publish.
        timeout : timedelta, optional
            The maximum time to wait to publish before timing out.
            Default is 2 seconds.

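        Examples
        --------
        A sketch, assuming ``pub`` is a registered publisher entered via
        its context manager and ``block`` is a SeriesBlock covering its
        channels:

        >>> from datetime import timedelta
        >>> pub.publish(block, timeout=timedelta(seconds=5))  # doctest: +SKIP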

        """
        # the producer is only bound once the context manager calls enter();
        # note: a truthiness check on the producer itself would misfire, since
        # confluent_kafka's Producer defines __len__ (messages in queue) and
        # an idle producer is therefore falsy
        if not hasattr(self, "_producer"):
            msg = (
                "publication interface not initialized, "
                "please use context manager when publishing."
            )
            raise RuntimeError(msg)

        for name, channel in block.channels.items():
            # .get() so an unknown channel name raises the ValueError below
            # rather than a bare KeyError
            if channel != self.channels.get(name):
                msg = f"invalid channel for this publisher: {channel}"
                raise ValueError(msg)

        # FIXME: updating partitions should only be allowed for
        # special blessed publishers, which are currently not using
        # this interface, so we're disabling this functionality for
        # the time being, until we have a better way to manage it.
        #
        # # check for new metadata changes
        # changed = set(block.channels.values()) - set(self.channels.values())
        # # exchange to transfer metadata and get new/updated partition IDs
        # if changed:
        #     self._update_partitions(changed)

        # publish the data, split into row batches keyed by partition ID
        for partition_id, batch in block.to_row_batches(self._partitions):
            topic = f"arrakis-{partition_id}"
            logger.debug("publishing to topic %s: %s", topic, batch)
            self._producer.produce(topic=topic, value=serialize_batch(batch))
        # bound the flush by the requested timeout
        self._producer.flush(timeout.total_seconds())

    def _update_partitions(
        self, channels: Iterable[Channel]
    ) -> None:  # pragma: no cover
        # set up flight
        assert self._registered, "has not registered yet"
        # channels may be a one-shot iterable and is traversed several
        # times below, so materialize it first
        channels = list(channels)
        descriptor = create_descriptor(
            RequestType.Partition,
            publisher_id=self.publisher_id,
            validator=self._validator,
        )
        # FIXME: should we not get FlightInfo first?
        with connect(self.initial_url) as client:
            writer, reader = client.do_exchange(descriptor)

            # send over list of channels to map new/updated partitions for
            dtypes = [channel_to_dtype_name(channel) for channel in channels]
            schema = pyarrow.schema(
                [
                    pyarrow.field("channel", pyarrow.string(), nullable=False),
                    pyarrow.field("data_type", pyarrow.string(), nullable=False),
                    pyarrow.field("sample_rate", pyarrow.int32(), nullable=False),
                    pyarrow.field("partition_id", pyarrow.string()),
                ]
            )
            batch = pyarrow.RecordBatch.from_arrays(
                [
                    pyarrow.array(
                        [str(channel) for channel in channels],
                        type=schema.field("channel").type,
                    ),
                    pyarrow.array(dtypes, type=schema.field("data_type").type),
                    pyarrow.array(
                        [channel.sample_rate for channel in channels],
                        type=schema.field("sample_rate").type,
                    ),
                    pyarrow.array(
                        [None for _ in channels],
                        type=schema.field("partition_id").type,
                    ),
                ],
                schema=schema,
            )

            # send over the partitions
            writer.begin(schema)
            writer.write_batch(batch)
            writer.done_writing()

            # get back the partition IDs and update
            partitions = reader.read_all().to_pydict()
            for channel, id_ in zip(partitions["channel"], partitions["partition_id"]):
                self._partitions[channel] = id_

    def close(self) -> None:
        """Flush any outstanding messages in the Kafka producer."""
        logger.info("closing kafka producer...")
        with contextlib.suppress(Exception):
            self._producer.flush()

    def __exit__(self, *exc) -> Literal[False]:
        self.close()
        return False