Coverage for src/dataknobs_data/vector/optimizations.py: 0%

212 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""Vector store optimization and performance enhancements.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7from collections import deque 

8from dataclasses import dataclass 

9from threading import Lock 

10from typing import Any, TYPE_CHECKING 

11 

12import numpy as np 

13 

14if TYPE_CHECKING: 

15 from collections.abc import Callable 

16 from .types import DistanceMetric 

17 

18 

19logger = logging.getLogger(__name__) 

20 

21 

@dataclass
class BatchConfig:
    """Configuration for batch operations.

    Controls batch sizing, queue bounds, auto-flush cadence, parallelism,
    and retry behavior for a ``BatchProcessor``.
    """

    size: int = 100  # items per batch; reaching this queue depth triggers a flush
    max_queue_size: int = 10000  # deque maxlen bound — oldest entries drop silently past this
    flush_interval: float = 1.0  # seconds between automatic flushes
    parallel_workers: int = 4  # >1 splits each batch into chunks processed concurrently
    retry_on_failure: bool = True  # re-queue items whose callback raised
    max_retries: int = 3  # intended per-item retry cap — NOTE(review): not visibly enforced by BatchProcessor; confirm

32 

33 

@dataclass
class ConnectionPoolConfig:
    """Configuration for connection pooling."""

    min_connections: int = 1  # NOTE(review): not read by ConnectionPool in this module; confirm intent
    max_connections: int = 10  # cap on simultaneously live connections
    connection_timeout: float = 30.0  # seconds — NOTE(review): not visibly enforced here; confirm
    idle_timeout: float = 300.0  # seconds — NOTE(review): not visibly enforced here; confirm
    recycle_timeout: float = 3600.0  # seconds — NOTE(review): not visibly enforced here; confirm

43 

44 

class BatchProcessor:
    """Handles batch processing of vector operations.

    Items are queued as ``(item, callback, attempts)`` triples and processed
    one batch at a time, either sequentially or split across parallel
    workers. Failed items are re-queued up to ``config.max_retries`` total
    attempts when ``config.retry_on_failure`` is set.
    """

    def __init__(self, config: BatchConfig | None = None):
        """Initialize the batch processor.

        Args:
            config: Batch configuration; defaults to ``BatchConfig()``.
        """
        self.config = config or BatchConfig()
        # maxlen bounds memory: the oldest queued entries are silently
        # dropped once max_queue_size is exceeded.
        self.queue: deque = deque(maxlen=self.config.max_queue_size)
        self.lock = Lock()
        self.processing = False
        self._flush_task: asyncio.Task | None = None

    async def add(self, item: Any, callback: Callable | None = None) -> None:
        """Add an item to the batch queue.

        Args:
            item: Item to process.
            callback: Optional callback (sync or async) invoked with the
                item when it is processed.
        """
        should_flush = False
        with self.lock:
            self.queue.append((item, callback, 0))
            # Check if we have accumulated a full batch.
            if len(self.queue) >= self.config.size:
                should_flush = True

        # Flush outside of the lock to avoid deadlock.
        if should_flush:
            await self.flush()

    async def flush(self) -> int:
        """Process up to one batch (``config.size`` items) from the queue.

        Returns:
            Number of items processed.
        """
        items_to_process = []

        with self.lock:
            # Drain at most one batch worth of items.
            batch_size = min(len(self.queue), self.config.size)
            for _ in range(batch_size):
                if self.queue:
                    items_to_process.append(self.queue.popleft())

        if not items_to_process:
            return 0

        # Process items in parallel if configured.
        if self.config.parallel_workers > 1:
            return await self._process_parallel(items_to_process)
        return await self._process_sequential(items_to_process)

    async def _process_sequential(self, items: list[tuple]) -> int:
        """Process items sequentially.

        Args:
            items: List of (item, callback, attempts) triples.

        Returns:
            Number of items processed.
        """
        processed = 0
        for item, callback, attempts in items:
            try:
                if callback:
                    if asyncio.iscoroutinefunction(callback):
                        await callback(item)
                    else:
                        callback(item)
                processed += 1
            except Exception as e:
                logger.error(f"Error processing item: {e}")
                # BUGFIX: previously items were re-queued unconditionally,
                # so max_retries was never honored and a persistently
                # failing item could retry forever. Re-queue only while
                # the per-item attempt budget remains.
                if (self.config.retry_on_failure
                        and attempts + 1 < self.config.max_retries):
                    with self.lock:
                        self.queue.append((item, callback, attempts + 1))

        return processed

    async def _process_parallel(self, items: list[tuple]) -> int:
        """Process items in parallel.

        Args:
            items: List of (item, callback, attempts) triples.

        Returns:
            Number of items processed.
        """
        # Split items into roughly equal chunks, one per worker.
        chunk_size = len(items) // self.config.parallel_workers
        if chunk_size == 0:
            chunk_size = 1

        chunks = [
            items[i:i + chunk_size]
            for i in range(0, len(items), chunk_size)
        ]

        tasks = [self._process_sequential(chunk) for chunk in chunks]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Sum successful per-chunk counts; exceptions returned by gather
        # are simply not counted.
        return sum(r for r in results if isinstance(r, int))

    async def start_auto_flush(self) -> None:
        """Start automatic flushing at intervals."""
        if self._flush_task is None or self._flush_task.done():
            self._flush_task = asyncio.create_task(self._auto_flush_loop())

    async def stop_auto_flush(self) -> None:
        """Stop automatic flushing."""
        if self._flush_task and not self._flush_task.done():
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

    async def _auto_flush_loop(self) -> None:
        """Background task: flush the queue every ``flush_interval`` seconds."""
        while True:
            try:
                await asyncio.sleep(self.config.flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in auto-flush: {e}")

188 

189 

class VectorOptimizer:
    """Optimizes vector operations for better performance."""

    @staticmethod
    def optimize_batch_size(
        num_vectors: int,
        vector_dim: int,
        available_memory: int = 1024 * 1024 * 1024  # 1GB default
    ) -> int:
        """Calculate optimal batch size based on available resources.

        Args:
            num_vectors: Total number of vectors.
            vector_dim: Dimension of each vector.
            available_memory: Available memory in bytes.

        Returns:
            Optimal batch size.
        """
        # float32 payload (4 bytes per component) plus an estimated 50%
        # overhead for metadata and indexing.
        per_vector_bytes = int(vector_dim * 4 * 1.5)

        # How many vectors fit in the memory budget.
        memory_bound = available_memory // per_vector_bytes

        # Clamp into a sensible range and never exceed the dataset size.
        return max(min(memory_bound, 10000, num_vectors), 10)

    @staticmethod
    def select_index_type(
        num_vectors: int,
        vector_dim: int,
        metric: DistanceMetric
    ) -> dict[str, Any]:
        """Select optimal index type based on dataset characteristics.

        Args:
            num_vectors: Number of vectors.
            vector_dim: Vector dimensions.
            metric: Distance metric.

        Returns:
            Index configuration.
        """
        # Small datasets: exact (flat) search is cheap enough.
        if num_vectors < 10000:
            return {"metric": metric, "type": "flat"}

        # Medium datasets: IVF with ~sqrt(n) clusters, clamped to [100, 4096].
        if num_vectors < 1000000:
            nlist = min(max(int(np.sqrt(num_vectors)), 100), 4096)
            return {
                "metric": metric,
                "type": "ivfflat",
                "nlist": nlist,
                "nprobe": min(nlist // 10, 64),
            }

        # Large datasets: graph-based HNSW.
        return {
            "metric": metric,
            "type": "hnsw",
            "m": 16,  # number of connections per node
            "ef_construction": 200,
            "ef_search": 50,
        }

    @staticmethod
    def optimize_search_params(
        index_type: str,
        recall_target: float = 0.95
    ) -> dict[str, Any]:
        """Optimize search parameters for target recall.

        Args:
            index_type: Type of index.
            recall_target: Target recall rate (0-1).

        Returns:
            Optimized search parameters.
        """
        # Flat indexes are always exact; nothing to tune. Unknown index
        # types likewise get no parameters.
        if index_type == "flat":
            return {}

        # (recall floor, nprobe, ef_search), strictest first; the defaults
        # below apply when no floor is met.
        ladder = ((0.99, 128, 200), (0.95, 64, 100), (0.90, 32, 50))
        nprobe, ef_search = 16, 32
        for floor, probe, ef in ladder:
            if recall_target >= floor:
                nprobe, ef_search = probe, ef
                break

        if index_type == "ivfflat":
            return {"nprobe": nprobe}
        if index_type == "hnsw":
            return {"ef_search": ef_search}
        return {}

312 

313 

class ConnectionPool:
    """Manages a pool of connections for vector stores.

    Connections are created lazily via ``factory`` up to
    ``config.max_connections``; released connections are recycled through
    an internal free list.
    """

    def __init__(self,
                 factory: Callable,
                 config: ConnectionPoolConfig | None = None):
        """Initialize the connection pool.

        Args:
            factory: Async callable that creates a new connection.
            config: Pool configuration; defaults to ``ConnectionPoolConfig()``.
        """
        self.factory = factory
        self.config = config or ConnectionPoolConfig()
        self.available: deque = deque()  # idle connections ready for reuse
        self.in_use: set = set()  # connections currently checked out
        self.lock = Lock()
        self._pending = 0  # slots reserved while factory() is running
        self._closed = False

    async def acquire(self) -> Any:
        """Acquire a connection from the pool.

        Returns:
            A connection object.

        Raises:
            RuntimeError: If the pool has been closed.
            TimeoutError: If no connection becomes available within ~10s.
        """
        if self._closed:
            raise RuntimeError("Connection pool is closed")

        create = False
        with self.lock:
            # Try to reuse an available connection first.
            if self.available:
                conn = self.available.popleft()
                # TODO: Check if connection is still valid
                self.in_use.add(conn)
                return conn

            # Reserve a slot so the (possibly slow) factory call can run
            # outside the lock without over-subscribing max_connections.
            if len(self.in_use) + self._pending < self.config.max_connections:
                self._pending += 1
                create = True

        if create:
            # BUGFIX: the factory was previously awaited while holding the
            # lock, blocking every other caller (and the event loop) for
            # the duration of connection setup.
            try:
                conn = await self.factory()
            except Exception:
                with self.lock:
                    self._pending -= 1
                raise
            with self.lock:
                self._pending -= 1
                self.in_use.add(conn)
            return conn

        # Wait for a connection to become available (bounded, ~10s total,
        # to avoid an infinite loop).
        for _ in range(100):
            await asyncio.sleep(0.1)
            with self.lock:
                if self.available:
                    conn = self.available.popleft()
                    self.in_use.add(conn)
                    return conn

        raise TimeoutError("Could not acquire connection from pool")

    async def release(self, conn: Any) -> None:
        """Release a connection back to the pool.

        Args:
            conn: Connection to release.
        """
        with self.lock:
            if conn in self.in_use:
                self.in_use.remove(conn)
                # After close() the pool no longer recycles; the connection
                # is simply dropped.
                if not self._closed:
                    self.available.append(conn)

    async def close(self) -> None:
        """Close all connections in the pool.

        Marks the pool closed (further ``acquire`` calls raise), detaches
        every known connection — including ones still checked out — and
        closes each one if it exposes a ``close`` method.
        """
        self._closed = True

        with self.lock:
            all_conns = list(self.available) + list(self.in_use)
            self.available.clear()
            self.in_use.clear()

        # Close connections outside the lock: close() may be a coroutine
        # and must not be awaited while the lock is held.
        for conn in all_conns:
            if hasattr(conn, 'close'):
                try:
                    if asyncio.iscoroutinefunction(conn.close):
                        await conn.close()
                    else:
                        conn.close()
                except Exception as e:
                    logger.error(f"Error closing connection: {e}")

401 

402 

class QueryOptimizer:
    """Optimizes vector queries for better performance."""

    @staticmethod
    def should_use_index(
        num_vectors: int,
        k: int,
        filter_selectivity: float = 1.0
    ) -> bool:
        """Determine if index should be used for query.

        Args:
            num_vectors: Total number of vectors.
            k: Number of results to return.
            filter_selectivity: Estimated filter selectivity (0-1).

        Returns:
            True if index should be used.
        """
        # BUGFIX: guard against ZeroDivisionError on an empty store —
        # there is nothing for an index to accelerate.
        if num_vectors <= 0:
            return False

        # If we're retrieving a large fraction of vectors, a scan is
        # likely faster than an index lookup.
        if k / num_vectors > 0.1:
            return False

        # If the filter is very selective, scanning the filtered results
        # directly is cheaper.
        if filter_selectivity < 0.01:
            return False

        # Otherwise use the index.
        return True

    @staticmethod
    def optimize_reranking(
        initial_k: int,
        final_k: int,
        rerank_factor: float = 3.0
    ) -> int:
        """Calculate optimal number of candidates for reranking.

        Args:
            initial_k: Initial number of results.
            final_k: Final number of results after reranking.
            rerank_factor: Multiplier for candidates.

        Returns:
            Optimal number of candidates.
        """
        candidates = int(final_k * rerank_factor)

        # Clamp to [2*final_k, min(initial_k, 10*final_k)]; when the
        # bounds cross (initial_k < 2*final_k) the upper bound wins
        # because it is applied last.
        min_candidates = final_k * 2
        max_candidates = min(initial_k, final_k * 10)

        candidates = max(candidates, min_candidates)
        candidates = min(candidates, max_candidates)

        return candidates

459 

460 

# Public API of this module: batch processing, connection pooling, and
# index/query tuning helpers.
__all__ = [
    "BatchConfig",
    "BatchProcessor",
    "ConnectionPool",
    "ConnectionPoolConfig",
    "QueryOptimizer",
    "VectorOptimizer",
]