Coverage for arrakis/publish.py: 88.5%
87 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-13 15:09 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-13 15:09 -0700
1# Copyright (c) 2022, California Institute of Technology and contributors
2#
3# You should have received a copy of the licensing terms for this
4# software included in the file "LICENSE" located in the top-level
5# directory of this package. If you did not, you can view a copy at
6# https://git.ligo.org/ngdd/arrakis-python/-/raw/main/LICENSE
8"""Publisher API."""
10from __future__ import annotations
12import contextlib
13import logging
14from typing import TYPE_CHECKING, Literal
16import pyarrow
17from pyarrow.flight import connect
19from . import constants
20from .client import Client
21from .flight import (
22 MultiEndpointStream,
23 RequestType,
24 RequestValidator,
25 create_descriptor,
26 parse_url,
27)
29try:
30 from confluent_kafka import Producer
31except ImportError:
32 HAS_KAFKA = False
33else:
34 HAS_KAFKA = True
36if TYPE_CHECKING:
37 from collections.abc import Iterable
38 from datetime import timedelta
40 from .block import SeriesBlock
41 from .channel import Channel
44logger = logging.getLogger("arrakis")
47def serialize_batch(batch: pyarrow.RecordBatch):
48 """Serialize a record batch to bytes.
50 Parameters
51 ----------
52 batch : pyarrow.RecordBatch
53 The batch to serialize.
55 Returns
56 -------
57 bytes
58 The serialized buffer.
60 """
61 sink = pyarrow.BufferOutputStream()
62 with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
63 writer.write_batch(batch)
64 return sink.getvalue()
67def channel_to_dtype_name(channel: Channel) -> str:
68 """Given a channel, return the data type's name."""
69 assert channel.data_type is not None
70 if isinstance(channel.data_type, str):
71 return channel.data_type
72 return channel.data_type.name
75class Publisher:
76 """Publisher to publish timeseries to Arrakis service.
78 Parameters
79 ----------
80 id : str
81 Publisher ID string.
82 url : str
83 Initial Flight URL to connect to.
85 """
87 def __init__(self, publisher_id: str, url: str | None = None):
88 if not HAS_KAFKA:
89 msg = (
90 "Publishing requires confluent-kafka to be installed."
91 "This is provided by the 'publish' extra or it can be "
92 "installed manually through pip or conda."
93 )
94 raise ImportError(msg)
96 self.publisher_id = publisher_id
97 self.initial_url = parse_url(url)
99 self.channels: dict[str, Channel] = {}
101 self._producer: Producer
102 self._partitions: dict[str, str]
103 self._registered = False
104 self._validator = RequestValidator()
106 def register(self):
107 assert not self._registered, "has already registered"
109 self.channels = {
110 channel.name: channel
111 for channel in Client(self.initial_url).find(publisher=self.publisher_id)
112 }
113 if not self.channels:
114 msg = f"unknown publisher ID '{self.publisher_id}'."
115 raise ValueError(msg)
117 # extract the channel partition map
118 self._partitions = {}
119 for channel in self.channels.values():
120 if not channel.partition_id:
121 msg = f"could not determine partition_id for channel {channel}."
122 raise ValueError(msg)
123 self._partitions[channel.name] = channel.partition_id
125 self._registered = True
127 return self
129 def enter(self):
130 if not self._registered:
131 msg = "must register publisher interface before publishing."
132 raise RuntimeError(msg)
134 # get connection properties
135 descriptor = create_descriptor(
136 RequestType.Publish,
137 publisher_id=self.publisher_id,
138 validator=self._validator,
139 )
140 properties: dict[str, str] = {}
141 with connect(self.initial_url) as client:
142 flight_info = client.get_flight_info(descriptor)
143 with MultiEndpointStream(flight_info.endpoints, client) as stream:
144 for data in stream.unpack():
145 kv_pairs = data["properties"]
146 properties.update(dict(kv_pairs))
148 # set up producer
149 self._producer = Producer(
150 {
151 "message.max.bytes": 10_000_000, # 10 MB
152 "enable.idempotence": True,
153 **properties,
154 }
155 )
157 def __enter__(self) -> Publisher:
158 self.enter()
159 return self
161 def publish(
162 self,
163 block: SeriesBlock,
164 timeout: timedelta = constants.DEFAULT_TIMEOUT,
165 ) -> None:
166 """Publish timeseries data
168 Parameters
169 ----------
170 block : SeriesBlock
171 A data block with all channels to publish.
172 timeout : timedelta, optional
173 The maximum time to wait to publish before timing out.
174 Default is 2 seconds.
176 """
177 if not hasattr(self, "_producer") or not self._producer:
178 msg = (
179 "publication interface not initialized, "
180 "please use context manager when publishing."
181 )
182 raise RuntimeError(msg)
184 for name, channel in block.channels.items():
185 if channel != self.channels[name]:
186 msg = f"invalid channel for this publisher: {channel}"
187 raise ValueError(msg)
189 # FIXME: updating partitions should only be allowed for
190 # special blessed publishers, that are currently not using
191 # this interface, so we're disabling this functionality for
192 # the time being, until we have a better way to manage it.
193 #
194 # # check for new metadata changes
195 # changed = set(block.channels.values()) - set(self.channels.values())
196 # # exchange to transfer metadata and get new/updated partition IDs
197 # if changed:
198 # self._update_partitions(changed)
200 # publish data for each data type, splitting into
201 # subblocks based on a maximum channel maximum
202 for partition_id, batch in block.to_row_batches(self._partitions):
203 topic = f"arrakis-{partition_id}"
204 logger.debug("publishing to topic %s: %s", topic, batch)
205 self._producer.produce(topic=topic, value=serialize_batch(batch))
206 self._producer.flush()
208 def _update_partitions(
209 self, channels: Iterable[Channel]
210 ) -> None: # pragma: no cover
211 # set up flight
212 assert self._registered, "has not registered yet"
213 descriptor = create_descriptor(
214 RequestType.Partition,
215 publisher_id=self.publisher_id,
216 validator=self._validator,
217 )
218 # FIXME: should we not get FlightInfo first?
219 with connect(self.initial_url) as client:
220 writer, reader = client.do_exchange(descriptor)
222 # send over list of channels to map new/updated partitions for
223 dtypes = [channel_to_dtype_name(channel) for channel in channels]
224 schema = pyarrow.schema(
225 [
226 pyarrow.field("channel", pyarrow.string(), nullable=False),
227 pyarrow.field("data_type", pyarrow.string(), nullable=False),
228 pyarrow.field("sample_rate", pyarrow.int32(), nullable=False),
229 pyarrow.field("partition_id", pyarrow.string()),
230 ]
231 )
232 batch = pyarrow.RecordBatch.from_arrays(
233 [
234 pyarrow.array(
235 [str(channel) for channel in channels],
236 type=schema.field("channel").type,
237 ),
238 pyarrow.array(dtypes, type=schema.field("data_type").type),
239 pyarrow.array(
240 [channel.sample_rate for channel in channels],
241 type=schema.field("sample_rate").type,
242 ),
243 pyarrow.array(
244 [None for _ in channels],
245 type=schema.field("partition_id").type,
246 ),
247 ],
248 schema=schema,
249 )
251 # send over the partitions
252 writer.begin(schema)
253 writer.write_batch(batch)
254 writer.done_writing()
256 # get back the partition IDs and update
257 partitions = reader.read_all().to_pydict()
258 for channel, id_ in zip(partitions["channel"], partitions["partition_id"]):
259 self._partitions[channel] = id_
261 def close(self) -> None:
262 logger.info("closing kafka producer...")
263 with contextlib.suppress(Exception):
264 self._producer.flush()
266 def __exit__(self, *exc) -> Literal[False]:
267 self.close()
268 return False