Coverage for gcsfs/prefetcher.py: 99%
314 statements
« prev ^ index » next coverage.py v7.9.1, created at 2026-04-20 18:41 -0400
« prev ^ index » next coverage.py v7.9.1, created at 2026-04-20 18:41 -0400
1import asyncio
2import ctypes
3import logging
4import threading
5from collections import deque
7import fsspec.asyn
9logger = logging.getLogger(__name__)
11PyBytes_FromStringAndSize = ctypes.pythonapi.PyBytes_FromStringAndSize
12PyBytes_FromStringAndSize.restype = ctypes.py_object
13PyBytes_FromStringAndSize.argtypes = [ctypes.c_void_p, ctypes.c_ssize_t]
15PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString
16PyBytes_AsString.restype = ctypes.c_void_p
17PyBytes_AsString.argtypes = [ctypes.py_object]
20# Please refer to following discussion to understand why this is required at this point
21# Discussion = https://github.com/fsspec/gcsfs/pull/795#discussion_r3032749881
22def _fast_slice(src_bytes, offset, read_size):
23 if read_size == 0:
24 return b""
25 dest_bytes = PyBytes_FromStringAndSize(None, read_size)
26 src_ptr = PyBytes_AsString(src_bytes)
27 dest_ptr = PyBytes_AsString(dest_bytes)
29 ctypes.memmove(dest_ptr, src_ptr + offset, read_size)
30 return dest_bytes
class RunningAverageTracker:
    """Tracks a running average of values over a sliding window.

    This is used to monitor read sizes and adaptively scale the
    prefetching strategy based on recent user behavior.
    """

    def __init__(self, maxlen=10):
        """Initializes the tracker with a specific window size.

        Args:
            maxlen (int): The maximum number of historical values to keep.
        """
        # %s rather than %d so that an unbounded window (maxlen=None) can be
        # logged without a formatting error.
        logger.debug("Initializing RunningAverageTracker with maxlen: %s", maxlen)
        self._history = deque(maxlen=maxlen)
        self._sum = 0

    def add(self, value: int):
        """Adds a new value to the sliding window and updates the rolling sum.

        Args:
            value (int): The integer value to add to the history.

        Raises:
            ValueError: If ``value`` is zero or negative.
        """
        if value <= 0:
            raise ValueError(
                "Internal error, RunningAverageTracker tried inserting negative value"
            )

        # A zero-length window can never retain a sample. Bail out so that we
        # neither index into an empty deque nor let the sum drift away from
        # the (always empty) history.
        if self._history.maxlen == 0:
            return

        # When the window is full, append() evicts the oldest entry, so remove
        # its contribution from the rolling sum first.
        if len(self._history) == self._history.maxlen:
            self._sum -= self._history[0]

        self._history.append(value)
        self._sum += value
        logger.debug(
            "RunningAverageTracker added value: %d, new sum: %d", value, self._sum
        )

    @property
    def average(self) -> int:
        """Calculates and returns the current running average.

        Returns:
            int: The integer (floor) average of the current history, or 1MB
                when no values have been recorded.
        """
        count = len(self._history)
        if count == 0:
            return 1024 * 1024  # 1MB
        return self._sum // count

    def clear(self):
        """Clears the history and resets the sum to zero."""
        logger.debug("Clearing RunningAverageTracker history.")
        self._history.clear()
        self._sum = 0
class PrefetchProducer:
    """Background worker that fetches sequential blocks of data.

    This class handles the network requests. It spawns asynchronous tasks
    to fetch data ahead of the user's current reading position and
    places those task promises into a queue for the consumer.
    """

    # If the request is too small, and prefetch window is expanded till 5MB
    # we then make request in 5MB blocks.
    MIN_CHUNK_SIZE = 5 * 1024 * 1024

    # If user doesn't specify any max_prefetch_size, the prefetcher defaults
    # to maximum of 2 * io_size and 128MB
    MIN_PREFETCH_SIZE = 128 * 1024 * 1024

    def __init__(
        self,
        fetcher,
        size: int,
        concurrency: int,
        queue: asyncio.Queue,
        wakeup_event: asyncio.Event,
        get_user_offset,
        get_io_size,
        get_sequential_streak,
        on_error,
        user_max_prefetch_size=None,
    ):
        """Initializes the background producer.

        Args:
            fetcher (Callable): A coroutine function to fetch bytes from a remote source.
                It is invoked as ``fetcher(offset, size, split_factor=...)``.
            size (int): Total size of the file being fetched.
            concurrency (int): Maximum number of concurrent fetch tasks.
            queue (asyncio.Queue): The shared queue to push download tasks into.
            wakeup_event (asyncio.Event): Event used to wake the producer from an idle state.
            get_user_offset (Callable): Function returning the user's current read offset.
            get_io_size (Callable): Function returning the adaptive IO size.
            get_sequential_streak (Callable): Function returning the current sequential read streak.
            on_error (Callable): Callback triggered when a background error occurs.
            user_max_prefetch_size (int, optional): A hard limit for prefetch size overrides.
        """
        logger.debug(
            "Initializing PrefetchProducer: size=%d, concurrency=%d, user_max_prefetch_size=%s",
            size,
            concurrency,
            user_max_prefetch_size,
        )
        self.fetcher = fetcher
        self.size = size
        self.concurrency = concurrency
        self.queue = queue
        self.wakeup_event = wakeup_event

        self.get_user_offset = get_user_offset
        self.get_io_size = get_io_size
        self.get_sequential_streak = get_sequential_streak
        self.on_error = on_error
        self._user_max_prefetch_size = user_max_prefetch_size

        self.current_offset = 0
        self.is_stopped = False
        self._active_tasks = set()
        self._producer_task = None

    @property
    def max_prefetch_size(self) -> int:
        """Calculates the maximum prefetch size based on user intent or io size.

        Returns:
            int: The maximum number of bytes to prefetch ahead.
        """
        adaptive_limit = max(2 * self.get_io_size(), self.MIN_PREFETCH_SIZE)
        if self._user_max_prefetch_size is not None:
            # A user supplied value can only lower the adaptive limit.
            return min(self._user_max_prefetch_size, adaptive_limit)
        return adaptive_limit

    def start(self):
        """Starts the background producer loop.

        This clears any previous wakeup events and spawns the main loop task.
        """
        logger.debug("Starting PrefetchProducer loop.")
        self.is_stopped = False
        self.wakeup_event.clear()
        self._producer_task = asyncio.create_task(self._loop())

    async def stop(self):
        """Cancels all active fetch tasks and shuts down the producer loop.

        This method ensures the queue is flushed and waits for cancelled
        tasks to finish cleaning up.
        """
        logger.debug(
            "Stopping PrefetchProducer. Active fetch tasks: %d", len(self._active_tasks)
        )
        self.is_stopped = True
        self.wakeup_event.set()

        tasks_to_wait = []
        if self._producer_task and not self._producer_task.done():
            self._producer_task.cancel()
            tasks_to_wait.append(self._producer_task)

        for task in list(self._active_tasks):
            if not task.done():
                # Actually request cancellation; without this, stop() would
                # block until every in-flight download ran to completion.
                task.cancel()
                tasks_to_wait.append(task)
        self._active_tasks.clear()

        # Clear out any leftover items in the queue
        cleared_items = 0
        while not self.queue.empty():
            try:
                item = self.queue.get_nowait()
                if (
                    isinstance(item, asyncio.Task)
                    and item.done()
                    and not item.cancelled()
                ):
                    # Retrieve the exception (if any) so asyncio does not
                    # report it as never retrieved.
                    item.exception()
                cleared_items += 1
            except asyncio.QueueEmpty:
                break

        if cleared_items > 0:
            logger.debug(
                "Cleared %d leftover items from the queue during stop.", cleared_items
            )

        if tasks_to_wait:
            logger.debug(
                "Waiting for %d cancelled tasks to finish their teardown.",
                len(tasks_to_wait),
            )
            # return_exceptions=True keeps CancelledError (and any fetch
            # failure) from propagating out of stop().
            await asyncio.gather(*tasks_to_wait, return_exceptions=True)

    async def restart(self, new_offset: int):
        """Stops current tasks and restarts the background loop at a new byte offset.

        Args:
            new_offset (int): The new byte position to start prefetching from.
        """
        logger.debug("Restarting PrefetchProducer at new offset: %d", new_offset)
        await self.stop()
        self.current_offset = new_offset
        self.start()

    async def _loop(self):
        """The main background loop that calculates sizes and spawns fetch tasks."""
        logger.debug("PrefetchProducer internal loop is now running.")
        try:
            while not self.is_stopped:
                # Sleep until the consumer (or stop()) signals that work is needed.
                await self.wakeup_event.wait()
                self.wakeup_event.clear()

                if self.is_stopped:
                    break

                io_size = self.get_io_size()
                streak = self.get_sequential_streak()
                # The prefetch window grows with the sequential streak, capped
                # at max_prefetch_size.
                prefetch_size = min((streak + 1) * io_size, self.max_prefetch_size)

                logger.debug(
                    "Producer awake. Current offset: %d, User offset: %d, Prefetch size: %d",
                    self.current_offset,
                    self.get_user_offset(),
                    prefetch_size,
                )

                while (
                    not self.is_stopped
                    and (self.current_offset - self.get_user_offset()) < prefetch_size
                    and self.current_offset < self.size
                ):
                    user_offset = self.get_user_offset()
                    space_remaining = self.size - self.current_offset
                    prefetch_space_available = prefetch_size - (
                        self.current_offset - user_offset
                    )

                    if prefetch_size >= self.MIN_CHUNK_SIZE:
                        if prefetch_space_available >= self.MIN_CHUNK_SIZE:
                            actual_size = min(
                                max(self.MIN_CHUNK_SIZE, io_size), space_remaining
                            )
                        else:
                            # Not enough room left in the window for a full
                            # chunk; wait for the next wakeup.
                            break
                    else:
                        actual_size = min(io_size, space_remaining)

                    if streak < 2:
                        sfactor = self.concurrency
                    else:
                        # Scale the split factor by the share of the window
                        # this request covers, clamped to [1, concurrency].
                        sfactor = min(
                            self.concurrency,
                            max(
                                1,
                                actual_size * self.concurrency // prefetch_size,
                            ),
                        )

                    logger.debug(
                        "Spawning fetch task. Offset: %d, Size: %d, Split Factor: %d",
                        self.current_offset,
                        actual_size,
                        sfactor,
                    )

                    download_task = asyncio.create_task(
                        self.fetcher(
                            self.current_offset, actual_size, split_factor=sfactor
                        )
                    )
                    # Track the task so stop() can cancel it; it removes
                    # itself from the set once finished.
                    self._active_tasks.add(download_task)
                    download_task.add_done_callback(self._active_tasks.discard)

                    await self.queue.put(download_task)
                    self.current_offset += actual_size

        except asyncio.CancelledError:
            logger.debug("PrefetchProducer loop was cancelled.")
        except Exception as e:
            logger.error(
                "PrefetchProducer loop encountered an unexpected error: %s",
                e,
                exc_info=True,
            )
            self.is_stopped = True
            self.on_error(e)
            # Hand the exception to the consumer so a waiting read fails
            # instead of blocking on an empty queue.
            await self.queue.put(e)
class PrefetchConsumer:
    """Reads prefetched blocks off the shared queue and hands out byte ranges.

    The consumer keeps one "current" block plus a cursor into it. Requests are
    served from that block first; once it is used up the next fetch task is
    taken from the queue and awaited.
    """

    def __init__(
        self,
        queue: asyncio.Queue,
        wakeup_event: asyncio.Event,
        is_producer_stopped,
        on_error,
    ):
        """Sets up an empty consumer positioned at offset zero.

        Args:
            queue (asyncio.Queue): Shared queue that carries fetch tasks (or exceptions).
            wakeup_event (asyncio.Event): Set by the consumer to ask the producer for more data.
            is_producer_stopped (Callable): Returns True once the producer has been halted.
            on_error (Callable): Invoked with the exception when a fetch fails.
        """
        logger.debug("Initializing PrefetchConsumer.")
        self.queue = queue
        self.wakeup_event = wakeup_event
        self.is_producer_stopped = is_producer_stopped
        self.on_error = on_error
        self.sequential_streak = 0
        self.offset = 0
        self._current_block, self._current_block_idx = b"", 0

    def seek(self, new_offset: int):
        """Performs a hard seek: new offset, streak reset, buffer dropped.

        Args:
            new_offset (int): The byte position the consumer should continue from.
        """
        logger.debug(
            "Consumer executing hard seek to offset %d. Clearing internal buffer.",
            new_offset,
        )
        self.offset, self.sequential_streak = new_offset, 0
        self._current_block, self._current_block_idx = b"", 0

    def clear_buffer(self):
        """Drops the locally held block without touching offset or streak."""
        logger.debug("Consumer local block buffer cleared.")
        self._current_block, self._current_block_idx = b"", 0

    async def _load_next_block(self) -> bool:
        """Replaces the exhausted current block with the next one from the queue.

        Returns:
            bool: True when a non-empty block is now current; False when the
                end of data was reached (producer stopped with an empty queue,
                or the fetched block was empty).

        Raises:
            Exception: Whatever the producer queued as an error, or whatever
                the awaited fetch task raised.
        """
        if self.is_producer_stopped() and self.queue.empty():
            logger.debug("Consumer reached EOF.")
            return False

        if self.queue.empty():
            logger.debug("Queue is empty. Waking up producer.")
            self.wakeup_event.set()

        item = await self.queue.get()

        # The producer pushes a bare exception when its loop has failed.
        if isinstance(item, Exception):
            logger.error("Consumer retrieved an exception: %s", item)
            self.on_error(item)
            raise item

        try:
            fetched = await item

            self.sequential_streak += 1
            if self.sequential_streak >= 2:
                self.wakeup_event.set()

            self._current_block = fetched
            self._current_block_idx = 0
            # Evaluated inside the try so an unsized result is reported
            # through on_error like any other fetch failure.
            len(self._current_block)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error("Consumer caught an error: %s", e, exc_info=True)
            self.on_error(e)
            raise e

        return bool(self._current_block)

    async def _advance(self, size: int, save_data: bool) -> list[bytes]:
        """Moves the read position forward by up to ``size`` bytes.

        When ``save_data`` is true the bytes passed over are returned as a list
        of pieces; otherwise the list stays empty. Stops early at end of data.
        """
        collected = []
        remaining = size

        while remaining > 0:
            unread = len(self._current_block) - self._current_block_idx

            if unread == 0:
                if not await self._load_next_block():
                    break
                unread = len(self._current_block)

            step = min(remaining, unread)

            if save_data:
                whole_block = (
                    self._current_block_idx == 0
                    and step == len(self._current_block)
                )
                if whole_block:
                    # Hand out the block itself, no copy needed.
                    piece = self._current_block
                else:
                    # Native Python slicing was GIL bound in my experiments.
                    piece = await asyncio.to_thread(
                        _fast_slice, self._current_block, self._current_block_idx, step
                    )
                collected.append(piece)

            self._current_block_idx += step
            self.offset += step
            remaining -= step

        return collected

    async def consume(self, size: int) -> bytes:
        """Returns the next ``size`` bytes, waiting on the queue when needed.

        Args:
            size (int): Number of bytes wanted.

        Returns:
            bytes: The bytes read; fewer than ``size`` when the end of data is
                reached, and ``b""`` for a non-positive ``size``.

        Raises:
            Exception: Any error raised by the producer or its fetch tasks.
        """
        if size <= 0:
            return b""

        pieces = await self._advance(size, save_data=True)

        if len(pieces) > 1:
            return await asyncio.to_thread(b"".join, pieces)
        return pieces[0] if pieces else b""

    async def skip(self, size: int) -> None:
        """Moves past ``size`` bytes without collecting them."""
        await self._advance(size, save_data=False)
479class BackgroundPrefetcher:
480 """Orchestrator that manages reading behavior and coordinates background work.
482 This acts as the main public interface for the file reader. It tracks the
483 user's reading history, routes seek operations, and links the producer's
484 network tasks with the consumer's data slicing logic.
485 """
487 def __init__(self, fetcher, size: int, concurrency: int, max_prefetch_size=None):
488 """Initializes the background prefetcher.
490 Args:
491 fetcher (Callable): A coroutine of the form `f(start, end)` which gets bytes from the remote.
492 size (int): Total byte size of the file being read.
493 concurrency (int): Number of concurrent network requests to use for large chunks.
494 max_prefetch_size (int, optional): Maximum bytes to prefetch ahead of the current user offset.
496 Raises:
497 ValueError: If max_prefetch_size is provided but is not a positive integer.
498 """
499 logger.debug(
500 "Starting BackgroundPrefetcher. Size: %d, Concurrency: %d, Max Prefetch: %s",
501 size,
502 concurrency,
503 max_prefetch_size,
504 )
505 self.size = size
506 self.concurrency = concurrency
508 if max_prefetch_size is not None and max_prefetch_size <= 0:
509 logger.error("Invalid max_prefetch_size provided: %s", max_prefetch_size)
510 raise ValueError(
511 "max_prefetch_size should be a positive integer to use adaptive prefetching!"
512 )
514 self.loop = fsspec.asyn.get_loop()
515 self._lock = threading.Lock()
516 self._error = None
517 self.is_stopped = False
518 self.queue = asyncio.Queue()
519 self.wakeup_event = asyncio.Event()
520 self.user_offset = 0
521 self.read_tracker = RunningAverageTracker(maxlen=10)
523 self.consumer = PrefetchConsumer(
524 queue=self.queue,
525 wakeup_event=self.wakeup_event,
526 is_producer_stopped=self._is_producer_stopped,
527 on_error=self._set_error,
528 )
530 self.producer = PrefetchProducer(
531 fetcher=fetcher,
532 size=self.size,
533 concurrency=self.concurrency,
534 queue=self.queue,
535 wakeup_event=self.wakeup_event,
536 get_user_offset=lambda: self.consumer.offset,
537 get_io_size=self._get_adaptive_io_size,
538 get_sequential_streak=lambda: self.consumer.sequential_streak,
539 on_error=self._set_error,
540 user_max_prefetch_size=max_prefetch_size,
541 )
543 async def _start():
544 self.producer.start()
546 fsspec.asyn.sync(self.loop, _start)
547 logger.debug("BackgroundPrefetcher initialization complete.")
549 def __enter__(self):
550 """Context manager entry point."""
551 return self
553 def __exit__(self, exc_type, exc_val, exc_tb):
554 """Context manager exit point. Ensures the prefetcher is cleanly closed."""
555 self.close()
557 def _get_adaptive_io_size(self) -> int:
558 return self.read_tracker.average
560 def _is_producer_stopped(self) -> bool:
561 return self.producer.is_stopped if hasattr(self, "producer") else True
563 def _set_error(self, e: Exception):
564 logger.error("Global error state set in BackgroundPrefetcher: %s", e)
565 self._error = e
567 async def _restart_producer(self, new_offset: int):
568 logger.debug(
569 "Handling seek request. Restarting producer at offset: %d", new_offset
570 )
571 self._error = None
572 await self.producer.restart(new_offset)
573 self.consumer.seek(new_offset)
574 self.read_tracker.clear()
576 async def _async_fetch(self, start, end):
577 logger.debug("Executing _async_fetch for range %d - %d.", start, end)
579 if start != self.user_offset:
580 if self.user_offset < start <= self.producer.current_offset:
581 logger.debug(
582 "Soft seek detected. Skipping ahead from %d to %d.",
583 self.user_offset,
584 start,
585 )
586 skip_amount = start - self.user_offset
587 await self.consumer.skip(skip_amount)
588 self.user_offset = start
589 else:
590 logger.debug(
591 "Hard seek detected. Moving user offset from %d to %d.",
592 self.user_offset,
593 start,
594 )
595 self.user_offset = start
596 await self._restart_producer(start)
598 requested_size = end - start
599 self.read_tracker.add(requested_size)
601 chunk = await self.consumer.consume(requested_size)
602 self.user_offset += len(chunk)
604 logger.debug("Completed _async_fetch. Returned %d bytes.", len(chunk))
605 return chunk
607 def _fetch(self, start: int | None, end: int | None) -> bytes:
608 if start is None:
609 start = 0
610 if end is None:
611 end = self.size
613 end = min(end, self.size)
614 logger.debug(
615 "Synchronous _fetch called for bounds start=%s, end=%s.", start, end
616 )
618 if start >= self.size or start >= end:
619 logger.warning(
620 "Invalid bounds or EOF reached in _fetch. Start: %d, End: %d, Size: %d",
621 start,
622 end,
623 self.size,
624 )
625 return b""
627 with self._lock:
628 if self._error:
629 logger.error("Cannot fetch data: instance has an active error state.")
630 raise self._error
632 if self.is_stopped:
633 logger.error(
634 "Cannot fetch data: BackgroundPrefetcher is stopped or closed."
635 )
636 raise RuntimeError(
637 "The file instance has been closed. This can occur if a close operation "
638 "is executed concurrently while a read operation is still in progress."
639 )
641 try:
642 result = fsspec.asyn.sync(self.loop, self._async_fetch, start, end)
643 except Exception as e:
644 logger.error(
645 "Exception raised during synchronous fetch: %s", e, exc_info=True
646 )
647 self.is_stopped = True
648 self._error = e
649 raise
651 if self.is_stopped:
652 logger.error("Instance was stopped mid-fetch operation.")
653 raise RuntimeError(
654 "The file instance has been closed. This can occur if a close operation "
655 "is executed concurrently while a read operation is still in progress."
656 )
658 return result
660 def close(self):
661 """Safely shuts down the prefetcher.
663 This cancels all background network tasks and blocks until everything
664 is completely cleaned up. It also clears the internal consumer buffer.
665 """
666 logger.debug("Closing BackgroundPrefetcher and cleaning up resources.")
667 if self.is_stopped:
668 logger.debug(
669 "BackgroundPrefetcher is already stopped. Skipping close operation."
670 )
671 return
673 self.is_stopped = True
674 with self._lock:
675 fsspec.asyn.sync(self.loop, self.producer.stop)
676 self.consumer.clear_buffer()
677 logger.debug("BackgroundPrefetcher closed successfully.")