Coverage for gcsfs/prefetcher.py: 99%

314 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2026-04-20 18:41 -0400

1import asyncio 

2import ctypes 

3import logging 

4import threading 

5from collections import deque 

6 

7import fsspec.asyn 

8 

9logger = logging.getLogger(__name__) 

10 

11PyBytes_FromStringAndSize = ctypes.pythonapi.PyBytes_FromStringAndSize 

12PyBytes_FromStringAndSize.restype = ctypes.py_object 

13PyBytes_FromStringAndSize.argtypes = [ctypes.c_void_p, ctypes.c_ssize_t] 

14 

15PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString 

16PyBytes_AsString.restype = ctypes.c_void_p 

17PyBytes_AsString.argtypes = [ctypes.py_object] 

18 

19 

20# Please refer to following discussion to understand why this is required at this point 

21# Discussion = https://github.com/fsspec/gcsfs/pull/795#discussion_r3032749881 

22def _fast_slice(src_bytes, offset, read_size): 

23 if read_size == 0: 

24 return b"" 

25 dest_bytes = PyBytes_FromStringAndSize(None, read_size) 

26 src_ptr = PyBytes_AsString(src_bytes) 

27 dest_ptr = PyBytes_AsString(dest_bytes) 

28 

29 ctypes.memmove(dest_ptr, src_ptr + offset, read_size) 

30 return dest_bytes 

31 

32 

class RunningAverageTracker:
    """Tracks a running average of values over a sliding window.

    This is used to monitor read sizes and adaptively scale the
    prefetching strategy based on recent user behavior.
    """

    # Value reported by ``average`` while no samples have been recorded (1MB).
    DEFAULT_AVERAGE = 1024 * 1024

    def __init__(self, maxlen=10):
        """Initializes the tracker with a specific window size.

        Args:
            maxlen (int | None): The maximum number of historical values to
                keep. ``None`` keeps every value (a cumulative average).

        Raises:
            ValueError: If ``maxlen`` is smaller than 1. A zero-length window
                can never hold a sample, and ``add`` would fail on it.
        """
        # %s rather than %d so that maxlen=None can be logged.
        logger.debug("Initializing RunningAverageTracker with maxlen: %s", maxlen)
        if maxlen is not None and maxlen < 1:
            raise ValueError(
                f"RunningAverageTracker maxlen must be at least 1, got {maxlen}"
            )
        self._history = deque(maxlen=maxlen)
        # Sum of the values currently in the window, kept in step with
        # _history so that `average` is O(1).
        self._sum = 0

    def add(self, value: int):
        """Adds a new value to the sliding window and updates the rolling sum.

        When the window is full, the oldest value is evicted (by the deque's
        ``maxlen``) and subtracted from the rolling sum.

        Args:
            value (int): The integer value to add to the history. Must be > 0.

        Raises:
            ValueError: If ``value`` is zero or negative.
        """
        if value <= 0:
            raise ValueError(
                "Internal error, RunningAverageTracker tried inserting negative value"
            )
        if len(self._history) == self._history.maxlen:
            # The append below evicts history[0]; drop it from the sum first.
            self._sum -= self._history[0]

        self._history.append(value)
        self._sum += value
        logger.debug(
            "RunningAverageTracker added value: %d, new sum: %d", value, self._sum
        )

    @property
    def average(self) -> int:
        """Calculates and returns the current running average.

        Returns:
            int: The floor of the mean of the current history, or
            ``DEFAULT_AVERAGE`` (1MB) when the history is empty.
        """
        count = len(self._history)
        if count == 0:
            return self.DEFAULT_AVERAGE
        return self._sum // count

    def clear(self):
        """Clears the history and resets the sum to zero."""
        logger.debug("Clearing RunningAverageTracker history.")
        self._history.clear()
        self._sum = 0

86 

87 

class PrefetchProducer:
    """Background worker that fetches sequential blocks of data.

    This class handles the network requests. It spawns asynchronous tasks
    to fetch data ahead of the user's current reading position and
    places those task promises into a queue for the consumer.
    """

    # Once the prefetch window is at least this large (5MB), requests are
    # issued in blocks of at least this size instead of io_size.
    MIN_CHUNK_SIZE = 5 * 1024 * 1024

    # Lower bound of the adaptive prefetch window (128MB). Without a user limit
    # the window is max(2 * io_size, MIN_PREFETCH_SIZE).
    MIN_PREFETCH_SIZE = 128 * 1024 * 1024

    def __init__(
        self,
        fetcher,
        size: int,
        concurrency: int,
        queue: asyncio.Queue,
        wakeup_event: asyncio.Event,
        get_user_offset,
        get_io_size,
        get_sequential_streak,
        on_error,
        user_max_prefetch_size=None,
    ):
        """Initializes the background producer.

        Args:
            fetcher (Callable): Coroutine function called as
                ``fetcher(offset, size, split_factor=n)`` to fetch bytes from the remote.
            size (int): Total size of the file being fetched.
            concurrency (int): Maximum number of concurrent fetch tasks.
            queue (asyncio.Queue): The shared queue to push download tasks into.
            wakeup_event (asyncio.Event): Event used to wake the producer from an idle state.
            get_user_offset (Callable): Function returning the user's current read offset.
            get_io_size (Callable): Function returning the adaptive IO size.
            get_sequential_streak (Callable): Function returning the current sequential read streak.
            on_error (Callable): Callback triggered when a background error occurs.
            user_max_prefetch_size (int, optional): A hard limit for prefetch size overrides.
        """
        logger.debug(
            "Initializing PrefetchProducer: size=%d, concurrency=%d, user_max_prefetch_size=%s",
            size,
            concurrency,
            user_max_prefetch_size,
        )
        self.fetcher = fetcher
        self.size = size
        self.concurrency = concurrency
        self.queue = queue
        self.wakeup_event = wakeup_event

        self.get_user_offset = get_user_offset
        self.get_io_size = get_io_size
        self.get_sequential_streak = get_sequential_streak
        self.on_error = on_error
        self._user_max_prefetch_size = user_max_prefetch_size

        # Byte offset at which the next fetch task will start.
        self.current_offset = 0
        self.is_stopped = False
        # Fetch tasks that have not finished yet; each removes itself when done.
        self._active_tasks = set()
        self._producer_task = None

    @property
    def max_prefetch_size(self) -> int:
        """Calculates the maximum prefetch size based on user intent or io size.

        Returns:
            int: ``max(2 * io_size, MIN_PREFETCH_SIZE)``, capped by the user's
            limit when one was given.
        """
        adaptive_size = max(2 * self.get_io_size(), self.MIN_PREFETCH_SIZE)
        if self._user_max_prefetch_size is None:
            return adaptive_size
        return min(self._user_max_prefetch_size, adaptive_size)

    def start(self):
        """Starts the background producer loop.

        This clears any previous wakeup events and spawns the main loop task.
        Must be called while an event loop is running (uses ``create_task``).
        """
        logger.debug("Starting PrefetchProducer loop.")
        self.is_stopped = False
        self.wakeup_event.clear()
        self._producer_task = asyncio.create_task(self._loop())

    async def stop(self):
        """Cancels all active fetch tasks and shuts down the producer loop.

        The queue is drained (retrieving exceptions of finished tasks so they
        are not reported as unretrieved), and this coroutine waits for the
        cancelled tasks to finish cleaning up.
        """
        logger.debug(
            "Stopping PrefetchProducer. Active fetch tasks: %d", len(self._active_tasks)
        )
        self.is_stopped = True
        self.wakeup_event.set()

        tasks_to_wait = []
        if self._producer_task and not self._producer_task.done():
            self._producer_task.cancel()
            tasks_to_wait.append(self._producer_task)

        # Cancel in-flight downloads. Without the cancel, stop() (and thus every
        # hard seek via restart()) would block until every prefetched block had
        # finished downloading, only for the data to be thrown away.
        for task in list(self._active_tasks):
            if not task.done():
                task.cancel()
                tasks_to_wait.append(task)
        self._active_tasks.clear()

        # Clear out any leftover items in the queue
        cleared_items = 0
        while not self.queue.empty():
            try:
                item = self.queue.get_nowait()
                if (
                    isinstance(item, asyncio.Task)
                    and item.done()
                    and not item.cancelled()
                ):
                    # Mark a stored exception as retrieved.
                    item.exception()
                cleared_items += 1
            except asyncio.QueueEmpty:
                break

        if cleared_items > 0:
            logger.debug(
                "Cleared %d leftover items from the queue during stop.", cleared_items
            )

        if tasks_to_wait:
            logger.debug(
                "Waiting for %d cancelled tasks to finish their teardown.",
                len(tasks_to_wait),
            )
            await asyncio.gather(*tasks_to_wait, return_exceptions=True)

    async def restart(self, new_offset: int):
        """Stops current tasks and restarts the background loop at a new byte offset.

        Args:
            new_offset (int): The new byte position to start prefetching from.
        """
        logger.debug("Restarting PrefetchProducer at new offset: %d", new_offset)
        await self.stop()
        self.current_offset = new_offset
        self.start()

    async def _loop(self):
        """The main background loop that calculates sizes and spawns fetch tasks.

        Each time ``wakeup_event`` is set, fetch tasks are spawned until the
        data fetched ahead of the user offset reaches the current prefetch
        window, or until end of file.
        """
        logger.debug("PrefetchProducer internal loop is now running.")
        try:
            while not self.is_stopped:
                await self.wakeup_event.wait()
                self.wakeup_event.clear()

                if self.is_stopped:
                    break

                io_size = self.get_io_size()
                streak = self.get_sequential_streak()
                # The window grows with the sequential streak, up to the cap.
                prefetch_size = min((streak + 1) * io_size, self.max_prefetch_size)

                logger.debug(
                    "Producer awake. Current offset: %d, User offset: %d, Prefetch size: %d",
                    self.current_offset,
                    self.get_user_offset(),
                    prefetch_size,
                )

                while (
                    not self.is_stopped
                    and (self.current_offset - self.get_user_offset()) < prefetch_size
                    and self.current_offset < self.size
                ):
                    user_offset = self.get_user_offset()
                    space_remaining = self.size - self.current_offset
                    prefetch_space_available = prefetch_size - (
                        self.current_offset - user_offset
                    )

                    if prefetch_size >= self.MIN_CHUNK_SIZE:
                        if prefetch_space_available >= self.MIN_CHUNK_SIZE:
                            actual_size = min(
                                max(self.MIN_CHUNK_SIZE, io_size), space_remaining
                            )
                        else:
                            # Not enough room for a full chunk; wait for the
                            # reader to make progress.
                            break
                    else:
                        actual_size = min(io_size, space_remaining)

                    if streak < 2:
                        sfactor = self.concurrency
                    else:
                        # Split proportionally to this chunk's share of the window.
                        sfactor = min(
                            self.concurrency,
                            max(
                                1,
                                actual_size * self.concurrency // prefetch_size,
                            ),
                        )

                    logger.debug(
                        "Spawning fetch task. Offset: %d, Size: %d, Split Factor: %d",
                        self.current_offset,
                        actual_size,
                        sfactor,
                    )

                    download_task = asyncio.create_task(
                        self.fetcher(
                            self.current_offset, actual_size, split_factor=sfactor
                        )
                    )
                    self._active_tasks.add(download_task)
                    download_task.add_done_callback(self._active_tasks.discard)

                    await self.queue.put(download_task)
                    self.current_offset += actual_size

        except asyncio.CancelledError:
            # Raised when stop() cancels this task; the loop ends quietly and
            # the task completes normally rather than as cancelled.
            logger.debug("PrefetchProducer loop was cancelled.")
        except Exception as e:
            logger.error(
                "PrefetchProducer loop encountered an unexpected error: %s",
                e,
                exc_info=True,
            )
            self.is_stopped = True
            self.on_error(e)
            # Hand the error to the consumer through the queue so a pending
            # read raises instead of waiting forever.
            await self.queue.put(e)

322 

323 

class PrefetchConsumer:
    """Serves exact byte ranges to the reader out of prefetched blocks.

    Completed download tasks are taken off the shared queue one at a time and
    kept as the "current block". Reads are carved out of that block, and the
    next queued task is awaited once the block is used up.
    """

    def __init__(
        self,
        queue: asyncio.Queue,
        wakeup_event: asyncio.Event,
        is_producer_stopped,
        on_error,
    ):
        """Creates a consumer positioned at offset 0 with an empty buffer.

        Args:
            queue (asyncio.Queue): Shared queue of fetch tasks; it may also hold
                an exception object placed there by the producer.
            wakeup_event (asyncio.Event): Set to ask the producer for more data.
            is_producer_stopped (Callable): Returns True once the producer has halted.
            on_error (Callable): Invoked with any exception met while reading.
        """
        logger.debug("Initializing PrefetchConsumer.")
        self.queue = queue
        self.wakeup_event = wakeup_event
        self.is_producer_stopped = is_producer_stopped
        self.on_error = on_error
        # Number of blocks pulled from the queue since the last hard seek.
        self.sequential_streak = 0
        # Absolute file offset of the next byte this consumer will hand out.
        self.offset = 0
        self._drop_block()

    def _drop_block(self):
        """Forgets the buffered block without logging anything."""
        self._current_block = b""
        self._current_block_idx = 0

    def seek(self, new_offset: int):
        """Jumps to ``new_offset``, discarding the buffer and the streak.

        Args:
            new_offset (int): Absolute byte position the reader moves to.
        """
        logger.debug(
            "Consumer executing hard seek to offset %d. Clearing internal buffer.",
            new_offset,
        )
        self.offset = new_offset
        self.sequential_streak = 0
        self._drop_block()

    def clear_buffer(self):
        """Throws away the locally buffered block (used on shutdown/reset)."""
        logger.debug("Consumer local block buffer cleared.")
        self._drop_block()

    async def _next_block(self) -> int:
        """Pulls the next queued block into the local buffer.

        Returns:
            int: Length of the newly buffered block, or 0 at EOF (producer
            stopped and nothing queued).

        Raises:
            Exception: An exception object found on the queue, or whatever the
            fetch task raised; ``on_error`` is called first in both cases.
        """
        if self.queue.empty():
            if self.is_producer_stopped():
                logger.debug("Consumer reached EOF.")
                return 0
            logger.debug("Queue is empty. Waking up producer.")
            self.wakeup_event.set()

        item = await self.queue.get()

        if isinstance(item, Exception):
            logger.error("Consumer retrieved an exception: %s", item)
            self.on_error(item)
            raise item

        # CancelledError derives from BaseException, so it is not caught here
        # and propagates untouched.
        try:
            block = await item
            self.sequential_streak += 1
            if self.sequential_streak >= 2:
                # Sequential reading: let the producer keep the window topped up.
                self.wakeup_event.set()
            self._current_block = block
            self._current_block_idx = 0
            fresh = len(block)
        except Exception as err:
            logger.error("Consumer caught an error: %s", err, exc_info=True)
            self.on_error(err)
            raise

        return fresh

    async def _advance(self, size: int, save_data: bool) -> list[bytes]:
        """Moves forward ``size`` bytes, collecting them when ``save_data`` is set.

        Stops early at EOF or when an empty block is received.

        Returns:
            list[bytes]: The collected pieces (always empty when not saving).
        """
        if size <= 0:
            return []

        pieces = []
        remaining = size

        while remaining > 0:
            buffered = len(self._current_block) - self._current_block_idx
            if not buffered:
                buffered = await self._next_block()
                if not buffered:
                    break

            take = min(remaining, buffered)

            if save_data:
                start = self._current_block_idx
                if start == 0 and take == len(self._current_block):
                    # The entire block is wanted: hand it over without copying.
                    pieces.append(self._current_block)
                else:
                    # Copy in a worker thread; native Python slicing was GIL
                    # bound in the author's experiments.
                    pieces.append(
                        await asyncio.to_thread(
                            _fast_slice, self._current_block, start, take
                        )
                    )

            self._current_block_idx += take
            remaining -= take
            self.offset += take

        return pieces

    async def consume(self, size: int) -> bytes:
        """Returns the next ``size`` bytes from the buffer and the task queue.

        Args:
            size (int): Number of bytes wanted.

        Returns:
            bytes: The requested bytes; shorter than ``size`` if EOF is reached.

        Raises:
            Exception: Any exception raised by the producer's fetch tasks.
        """
        if size <= 0:
            return b""

        pieces = await self._advance(size, save_data=True)

        if len(pieces) > 1:
            return await asyncio.to_thread(b"".join, pieces)
        return pieces[0] if pieces else b""

    async def skip(self, size: int) -> None:
        """Moves the offset forward by ``size`` bytes without building any output."""
        await self._advance(size, save_data=False)

477 

478 

class BackgroundPrefetcher:
    """Orchestrator that manages reading behavior and coordinates background work.

    This acts as the main public interface for the file reader. It tracks the
    user's reading history, routes seek operations, and links the producer's
    network tasks with the consumer's data slicing logic.
    """

    def __init__(self, fetcher, size: int, concurrency: int, max_prefetch_size=None):
        """Initializes the background prefetcher and starts the producer loop.

        Args:
            fetcher (Callable): Coroutine function that the producer invokes as
                ``fetcher(offset, size, split_factor=n)`` to get bytes from the remote.
            size (int): Total byte size of the file being read.
            concurrency (int): Number of concurrent network requests to use for large chunks.
            max_prefetch_size (int, optional): Maximum bytes to prefetch ahead of the current user offset.

        Raises:
            ValueError: If max_prefetch_size is provided but is not a positive integer.
        """
        logger.debug(
            "Starting BackgroundPrefetcher. Size: %d, Concurrency: %d, Max Prefetch: %s",
            size,
            concurrency,
            max_prefetch_size,
        )
        self.size = size
        self.concurrency = concurrency

        if max_prefetch_size is not None and max_prefetch_size <= 0:
            logger.error("Invalid max_prefetch_size provided: %s", max_prefetch_size)
            raise ValueError(
                "max_prefetch_size should be a positive integer to use adaptive prefetching!"
            )

        # Async work runs on the loop returned by fsspec.asyn.get_loop(); the
        # synchronous methods bridge into it with fsspec.asyn.sync.
        self.loop = fsspec.asyn.get_loop()
        # Serialises _fetch and close() across calling threads.
        self._lock = threading.Lock()
        self._error = None
        self.is_stopped = False
        self.queue = asyncio.Queue()
        self.wakeup_event = asyncio.Event()
        # Offset at which the caller's next sequential read is expected.
        self.user_offset = 0
        self.read_tracker = RunningAverageTracker(maxlen=10)

        self.consumer = PrefetchConsumer(
            queue=self.queue,
            wakeup_event=self.wakeup_event,
            is_producer_stopped=self._is_producer_stopped,
            on_error=self._set_error,
        )

        self.producer = PrefetchProducer(
            fetcher=fetcher,
            size=self.size,
            concurrency=self.concurrency,
            queue=self.queue,
            wakeup_event=self.wakeup_event,
            get_user_offset=lambda: self.consumer.offset,
            get_io_size=self._get_adaptive_io_size,
            get_sequential_streak=lambda: self.consumer.sequential_streak,
            on_error=self._set_error,
            user_max_prefetch_size=max_prefetch_size,
        )

        async def _start():
            # start() uses asyncio.create_task, so it must run on the loop.
            self.producer.start()

        fsspec.asyn.sync(self.loop, _start)
        logger.debug("BackgroundPrefetcher initialization complete.")

    def __enter__(self):
        """Context manager entry point."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit point. Ensures the prefetcher is cleanly closed."""
        self.close()

    def _get_adaptive_io_size(self) -> int:
        """Returns the running average of recent read sizes."""
        return self.read_tracker.average

    def _is_producer_stopped(self) -> bool:
        """Reports whether the producer has halted (True before it exists)."""
        # The consumer is constructed before the producer, hence the hasattr.
        return self.producer.is_stopped if hasattr(self, "producer") else True

    def _set_error(self, e: Exception):
        """Records ``e`` so that subsequent fetches re-raise it."""
        logger.error("Global error state set in BackgroundPrefetcher: %s", e)
        self._error = e

    async def _restart_producer(self, new_offset: int):
        """Restarts prefetching at ``new_offset`` after a hard seek."""
        logger.debug(
            "Handling seek request. Restarting producer at offset: %d", new_offset
        )
        self._error = None
        await self.producer.restart(new_offset)
        self.consumer.seek(new_offset)
        self.read_tracker.clear()

    async def _async_fetch(self, start, end):
        """Returns bytes ``[start, end)``, seeking first if ``start`` is not sequential.

        A forward jump that stays within data already scheduled by the producer
        is served by skipping ("soft seek"); any other jump restarts the
        producer ("hard seek").
        """
        logger.debug("Executing _async_fetch for range %d - %d.", start, end)

        if start != self.user_offset:
            if self.user_offset < start <= self.producer.current_offset:
                logger.debug(
                    "Soft seek detected. Skipping ahead from %d to %d.",
                    self.user_offset,
                    start,
                )
                skip_amount = start - self.user_offset
                await self.consumer.skip(skip_amount)
                self.user_offset = start
            else:
                logger.debug(
                    "Hard seek detected. Moving user offset from %d to %d.",
                    self.user_offset,
                    start,
                )
                self.user_offset = start
                await self._restart_producer(start)

        requested_size = end - start
        self.read_tracker.add(requested_size)

        chunk = await self.consumer.consume(requested_size)
        self.user_offset += len(chunk)

        logger.debug("Completed _async_fetch. Returned %d bytes.", len(chunk))
        return chunk

    def _teardown_after_error(self):
        """Best-effort stop of background work once a fetch has failed.

        A failed fetch marks the instance stopped, and ``close()`` returns early
        for stopped instances, so without this the producer loop and any
        in-flight download tasks would be left running.
        """
        try:
            fsspec.asyn.sync(self.loop, self.producer.stop)
        except Exception as stop_error:
            logger.warning(
                "Failed to stop producer after fetch error: %s", stop_error
            )
        self.consumer.clear_buffer()

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        """Synchronously returns bytes ``[start, end)``, clamped to the file size.

        Args:
            start (int | None): First byte offset; ``None`` means 0.
            end (int | None): End offset (exclusive); ``None`` means file size.

        Returns:
            bytes: The requested data, or ``b""`` for an empty/out-of-range request.

        Raises:
            Exception: A previously recorded error, or the error raised by this
                fetch (which also stops the background producer).
            RuntimeError: If the prefetcher is closed before or during the call.
        """
        if start is None:
            start = 0
        if end is None:
            end = self.size

        end = min(end, self.size)
        logger.debug(
            "Synchronous _fetch called for bounds start=%s, end=%s.", start, end
        )

        if start >= self.size or start >= end:
            logger.warning(
                "Invalid bounds or EOF reached in _fetch. Start: %d, End: %d, Size: %d",
                start,
                end,
                self.size,
            )
            return b""

        with self._lock:
            if self._error:
                logger.error("Cannot fetch data: instance has an active error state.")
                raise self._error

            if self.is_stopped:
                logger.error(
                    "Cannot fetch data: BackgroundPrefetcher is stopped or closed."
                )
                raise RuntimeError(
                    "The file instance has been closed. This can occur if a close operation "
                    "is executed concurrently while a read operation is still in progress."
                )

            try:
                result = fsspec.asyn.sync(self.loop, self._async_fetch, start, end)
            except Exception as e:
                logger.error(
                    "Exception raised during synchronous fetch: %s", e, exc_info=True
                )
                self.is_stopped = True
                self._error = e
                self._teardown_after_error()
                raise

            if self.is_stopped:
                logger.error("Instance was stopped mid-fetch operation.")
                raise RuntimeError(
                    "The file instance has been closed. This can occur if a close operation "
                    "is executed concurrently while a read operation is still in progress."
                )

            return result

    def close(self):
        """Safely shuts down the prefetcher.

        This cancels all background network tasks and blocks until everything
        is completely cleaned up. It also clears the internal consumer buffer.
        Calling it on an already-stopped instance is a no-op.
        """
        logger.debug("Closing BackgroundPrefetcher and cleaning up resources.")
        if self.is_stopped:
            logger.debug(
                "BackgroundPrefetcher is already stopped. Skipping close operation."
            )
            return

        # Set before taking the lock so an in-progress _fetch notices the
        # close once it finishes and raises instead of returning data.
        self.is_stopped = True
        with self._lock:
            fsspec.asyn.sync(self.loop, self.producer.stop)
            self.consumer.clear_buffer()
        logger.debug("BackgroundPrefetcher closed successfully.")