Coverage for src/dataknobs_data/backends/memory.py: 22%

248 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 15:45 -0700

1"""In-memory database backend implementation.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import threading 

7import uuid 

8from collections import OrderedDict 

9from typing import Any, TYPE_CHECKING 

10 

11from dataknobs_config import ConfigurableBase 

12 

13from ..database import AsyncDatabase, SyncDatabase 

14from ..query_logic import ComplexQuery 

15from ..streaming import AsyncStreamingMixin, StreamConfig, StreamingMixin, StreamResult 

16from ..vector import VectorOperationsMixin 

17from ..vector.bulk_embed_mixin import BulkEmbedMixin 

18from ..vector.python_vector_search import PythonVectorSearchMixin 

19from .sqlite_mixins import SQLiteVectorSupport 

20from .vector_config_mixin import VectorConfigMixin 

21 

22if TYPE_CHECKING: 

23 from collections.abc import AsyncIterator, Iterator 

24 from ..query import Query 

25 from ..records import Record 

26 

27 

class AsyncMemoryDatabase(  # type: ignore[misc]
    AsyncDatabase,
    AsyncStreamingMixin,
    ConfigurableBase,
    VectorConfigMixin,        # Parse vector config
    SQLiteVectorSupport,      # Provides _compute_similarity
    PythonVectorSearchMixin,  # Provides python_vector_search_async
    BulkEmbedMixin,           # Bulk embedding operations
    VectorOperationsMixin     # Standard vector interface
):
    """Async in-memory database implementation.

    Records live in an ``OrderedDict`` keyed by storage ID. Records are
    deep-copied on the way in and out so callers can never mutate stored
    state, and a single ``asyncio.Lock`` serializes access across
    coroutines.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        super().__init__(config)
        self._storage: OrderedDict[str, Record] = OrderedDict()
        self._lock = asyncio.Lock()

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()  # From SQLiteVectorSupport

    @classmethod
    def from_config(cls, config: dict) -> AsyncMemoryDatabase:
        """Create from config dictionary."""
        return cls(config)

    def _generate_id(self) -> str:
        """Generate a unique ID for a record."""
        return str(uuid.uuid4())

    async def create(self, record: Record) -> str:
        """Create a new record in memory and return its storage ID."""
        async with self._lock:
            # Use centralized method to prepare record
            record_copy, storage_id = self._prepare_record_for_storage(record)

            # Store the record
            self._storage[storage_id] = record_copy
            return storage_id

    async def read(self, id: str) -> Record | None:
        """Read a record from memory; returns None when absent."""
        async with self._lock:
            record = self._storage.get(id)
            # Use centralized method to prepare record
            return self._prepare_record_from_storage(record, id)

    async def update(self, id: str, record: Record) -> bool:
        """Update a record in memory.

        Returns True if the record existed and was replaced, False otherwise.
        """
        async with self._lock:
            if id in self._storage:
                self._storage[id] = record.copy(deep=True)
                return True
            return False

    async def delete(self, id: str) -> bool:
        """Delete a record from memory; returns True if it existed."""
        async with self._lock:
            if id in self._storage:
                del self._storage[id]
                return True
            return False

    async def exists(self, id: str) -> bool:
        """Check if a record exists in memory."""
        async with self._lock:
            return id in self._storage

    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record with the specified ID.

        Overrides base class to handle memory-specific storage.

        Raises:
            ValueError: if an ID string is provided without a record.
        """
        # Use base class logic to determine ID and record
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                # Fix: reuse the class's own _generate_id() (module-level
                # uuid import) instead of re-importing uuid locally.
                id = self._generate_id()  # type: ignore[unreachable]
                record.storage_id = id

        # Memory-specific implementation
        async with self._lock:
            self._storage[id] = record.copy(deep=True)
            return id

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        # Handle ComplexQuery using base class implementation
        if isinstance(query, ComplexQuery):
            return await self._search_with_complex_query(query)

        async with self._lock:
            results = []

            for id, record in self._storage.items():
                # Apply filters
                matches = True
                for filter in query.filters:
                    # Special handling for 'id' field
                    if filter.field == 'id':
                        field_value = id
                    else:
                        field_value = record.get_value(filter.field)
                    if not filter.matches(field_value):
                        matches = False
                        break

                if matches:
                    results.append((id, record))

            # Use the helper method from base class
            return self._process_search_results(results, query, deep_copy=True)

    async def _count_all(self) -> int:
        """Count all records in memory."""
        async with self._lock:
            return len(self._storage)

    async def clear(self) -> int:
        """Clear all records from memory; returns how many were removed."""
        async with self._lock:
            count = len(self._storage)
            self._storage.clear()
            return count

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently; returns their storage IDs."""
        async with self._lock:
            ids = []
            for record in records:
                # Use centralized method to prepare record
                record_copy, storage_id = self._prepare_record_for_storage(record)

                # Store the record
                self._storage[storage_id] = record_copy
                ids.append(storage_id)
            return ids

    async def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records efficiently; None for each missing ID."""
        async with self._lock:
            results = []
            for id in ids:
                record = self._storage.get(id)
                # Use centralized method to prepare record
                results.append(self._prepare_record_from_storage(record, id))
            return results

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently; True per ID that existed."""
        async with self._lock:
            results = []
            for id in ids:
                if id in self._storage:
                    del self._storage[id]
                    results.append(True)
                else:
                    results.append(False)
            return results

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from memory, in batches of config.batch_size."""
        config = config or StreamConfig()

        # Get all matching records
        if query:
            records = await self.search(query)
        else:
            async with self._lock:
                # Ensure records have IDs when getting directly from storage
                records = []
                for record_id, record in self._storage.items():
                    record_copy = self._ensure_record_id(record, record_id)
                    records.append(record_copy)

        # Yield records in batches
        for i in range(0, len(records), config.batch_size):
            batch = records[i:i + config.batch_size]
            for record in batch:
                yield record.copy(deep=True)
            # Small yield to prevent blocking
            await asyncio.sleep(0)

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into memory."""
        # Use the default implementation from mixin
        return await self._default_stream_write(records, config)

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations."""
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )

249 

250 

class SyncMemoryDatabase(  # type: ignore[misc]
    SyncDatabase,
    StreamingMixin,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,
    BulkEmbedMixin,
    VectorOperationsMixin
):
    """Synchronous in-memory database implementation.

    Records live in an ``OrderedDict`` keyed by storage ID. Records are
    deep-copied on the way in and out so callers can never mutate stored
    state, and a ``threading.RLock`` serializes access across threads.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        super().__init__(config)
        self._storage: OrderedDict[str, Record] = OrderedDict()
        self._lock = threading.RLock()

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> SyncMemoryDatabase:
        """Create from config dictionary."""
        return cls(config)

    def _generate_id(self) -> str:
        """Generate a unique ID for a record."""
        return str(uuid.uuid4())

    def create(self, record: Record) -> str:
        """Create a new record in memory and return its storage ID."""
        with self._lock:
            # Use record's ID if it has one, otherwise generate a new one
            id = record.id if record.id else self._generate_id()
            self._storage[id] = record.copy(deep=True)
            return id

    def read(self, id: str) -> Record | None:
        """Read a record from memory; returns None when absent."""
        with self._lock:
            record = self._storage.get(id)
            return record.copy(deep=True) if record else None

    def update(self, id: str, record: Record) -> bool:
        """Update a record in memory.

        Returns True if the record existed and was replaced, False otherwise.
        """
        with self._lock:
            if id in self._storage:
                self._storage[id] = record.copy(deep=True)
                return True
            return False

    def delete(self, id: str) -> bool:
        """Delete a record from memory; returns True if it existed."""
        with self._lock:
            if id in self._storage:
                del self._storage[id]
                return True
            return False

    def exists(self, id: str) -> bool:
        """Check if a record exists in memory."""
        with self._lock:
            return id in self._storage

    def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record with the specified ID.

        Overrides base class to handle memory-specific storage.

        Raises:
            ValueError: if an ID string is provided without a record.
        """
        # Use base class logic to determine ID and record
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                # Fix: reuse the class's own _generate_id() (module-level
                # uuid import) instead of re-importing uuid locally.
                id = self._generate_id()  # type: ignore[unreachable]
                record.storage_id = id

        # Memory-specific implementation
        with self._lock:
            self._storage[id] = record.copy(deep=True)
            return id

    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        # Handle ComplexQuery using base class implementation
        if isinstance(query, ComplexQuery):
            return self._search_with_complex_query(query)

        with self._lock:
            results = []

            for id, record in self._storage.items():
                # Apply filters
                matches = True
                for filter in query.filters:
                    # Special handling for 'id' field
                    if filter.field == 'id':
                        field_value = id
                    else:
                        field_value = record.get_value(filter.field)
                    if not filter.matches(field_value):
                        matches = False
                        break

                if matches:
                    results.append((id, record))

            # Use the helper method from base class
            return self._process_search_results(results, query, deep_copy=True)

    def _count_all(self) -> int:
        """Count all records in memory."""
        with self._lock:
            return len(self._storage)

    def clear(self) -> int:
        """Clear all records from memory; returns how many were removed."""
        with self._lock:
            count = len(self._storage)
            self._storage.clear()
            return count

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently; returns their storage IDs."""
        with self._lock:
            ids = []
            for record in records:
                # Use record's ID if it has one, otherwise generate a new one
                id = record.id if record.id else self._generate_id()
                self._storage[id] = record.copy(deep=True)
                ids.append(id)
            return ids

    def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records efficiently; None for each missing ID."""
        with self._lock:
            results = []
            for id in ids:
                record = self._storage.get(id)
                results.append(record.copy(deep=True) if record else None)
            return results

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently; True per ID that existed."""
        with self._lock:
            results = []
            for id in ids:
                if id in self._storage:
                    del self._storage[id]
                    results.append(True)
                else:
                    results.append(False)
            return results

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from memory, in batches of config.batch_size."""
        config = config or StreamConfig()

        # Get all matching records
        if query:
            records = self.search(query)
        else:
            with self._lock:
                # Ensure records have IDs when getting directly from storage
                records = []
                for record_id, record in self._storage.items():
                    record_copy = self._ensure_record_id(record, record_id)
                    records.append(record_copy)

        # Yield records in batches
        for i in range(0, len(records), config.batch_size):
            batch = records[i:i + config.batch_size]
            for record in batch:
                yield record.copy(deep=True)

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into memory."""
        # Use the default implementation from mixin
        return self._default_stream_write(records, config)

    def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations."""
        return self.python_vector_search_sync(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )