Coverage for src/dataknobs_data/backends/memory.py: 33%

224 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-31 15:06 -0600

1"""In-memory database backend implementation.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import threading 

7import uuid 

8from collections import OrderedDict 

9from typing import Any, TYPE_CHECKING 

10 

11from dataknobs_config import ConfigurableBase 

12 

13from ..database import AsyncDatabase, SyncDatabase 

14from ..query_logic import ComplexQuery 

15from ..streaming import AsyncStreamingMixin, StreamConfig, StreamingMixin, StreamResult 

16from ..vector import VectorOperationsMixin 

17from ..vector.bulk_embed_mixin import BulkEmbedMixin 

18from ..vector.python_vector_search import PythonVectorSearchMixin 

19from .sqlite_mixins import SQLiteVectorSupport 

20from .vector_config_mixin import VectorConfigMixin 

21 

22if TYPE_CHECKING: 

23 from collections.abc import AsyncIterator, Iterator 

24 from ..query import Query 

25 from ..records import Record 

26 

27 

class AsyncMemoryDatabase(  # type: ignore[misc]
    AsyncDatabase,
    AsyncStreamingMixin,
    ConfigurableBase,
    VectorConfigMixin,        # Parses vector-related config keys
    SQLiteVectorSupport,      # Provides _compute_similarity / _init_vector_state
    PythonVectorSearchMixin,  # Provides python_vector_search_async
    BulkEmbedMixin,           # Bulk embedding operations
    VectorOperationsMixin     # Standard vector interface
):
    """Asynchronous in-memory database backed by an ``OrderedDict``.

    Every storage access is serialized through a single ``asyncio.Lock`` so
    concurrent coroutines never observe a partially applied mutation.
    Records are deep-copied on the way in and out, so callers can freely
    mutate what they hold without affecting stored state.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        super().__init__(config)
        self._storage: OrderedDict[str, Record] = OrderedDict()
        self._lock = asyncio.Lock()

        # Wire up vector support supplied by the mixins.
        self._parse_vector_config(config or {})
        self._init_vector_state()  # From SQLiteVectorSupport

    @classmethod
    def from_config(cls, config: dict) -> AsyncMemoryDatabase:
        """Build an instance from a plain config dictionary."""
        return cls(config)

    def _generate_id(self) -> str:
        """Return a fresh UUID4 string to use as a record ID."""
        return str(uuid.uuid4())

    async def create(self, record: Record) -> str:
        """Insert a new record and return the ID it was stored under."""
        async with self._lock:
            # Centralized helper handles copying and ID assignment.
            prepared, key = self._prepare_record_for_storage(record)
            self._storage[key] = prepared
            return key

    async def read(self, id: str) -> Record | None:
        """Fetch the record stored under ``id``, or ``None`` if absent."""
        async with self._lock:
            # Centralized helper copies the record and attaches its ID.
            return self._prepare_record_from_storage(self._storage.get(id), id)

    async def update(self, id: str, record: Record) -> bool:
        """Replace an existing record; return ``False`` when ``id`` is unknown."""
        async with self._lock:
            if id not in self._storage:
                return False
            self._storage[id] = record.copy(deep=True)
            return True

    async def delete(self, id: str) -> bool:
        """Remove the record under ``id``; return whether anything was removed."""
        async with self._lock:
            if id not in self._storage:
                return False
            del self._storage[id]
            return True

    async def exists(self, id: str) -> bool:
        """Return ``True`` when a record is stored under ``id``."""
        async with self._lock:
            return id in self._storage

    async def upsert(self, id: str, record: Record) -> str:
        """Insert or overwrite the record under ``id`` and echo the ID back."""
        async with self._lock:
            self._storage[id] = record.copy(deep=True)
            return id

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Return all records matching ``query`` (simple or complex)."""
        # Complex queries are delegated to the shared base implementation.
        if isinstance(query, ComplexQuery):
            return await self._search_with_complex_query(query)

        async with self._lock:
            # A record matches when every filter accepts its field value;
            # all() short-circuits exactly like the explicit break did.
            matched = [
                (key, rec)
                for key, rec in self._storage.items()
                if all(
                    flt.matches(rec.get_value(flt.field))
                    for flt in query.filters
                )
            ]
            # Sorting/pagination/copying are handled by the base-class helper.
            return self._process_search_results(matched, query, deep_copy=True)

    async def _count_all(self) -> int:
        """Return the total number of stored records."""
        async with self._lock:
            return len(self._storage)

    async def clear(self) -> int:
        """Drop every record and return how many were removed."""
        async with self._lock:
            removed = len(self._storage)
            self._storage.clear()
            return removed

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Insert many records in one locked pass; return their IDs in order."""
        async with self._lock:
            keys: list[str] = []
            for item in records:
                # Same centralized preparation as single-record create().
                prepared, key = self._prepare_record_for_storage(item)
                self._storage[key] = prepared
                keys.append(key)
            return keys

    async def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Fetch many records; missing IDs yield ``None`` at their position."""
        async with self._lock:
            return [
                self._prepare_record_from_storage(self._storage.get(key), key)
                for key in ids
            ]

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete many records; each slot reports whether the ID existed."""
        async with self._lock:
            outcome: list[bool] = []
            for key in ids:
                hit = key in self._storage
                if hit:
                    del self._storage[key]
                outcome.append(hit)
            return outcome

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Yield matching records batch by batch, cooperatively yielding control."""
        cfg = config or StreamConfig()

        # Collect the snapshot first so nothing is yielded while holding the lock.
        if query:
            snapshot = await self.search(query)
        else:
            async with self._lock:
                # Records pulled straight from storage must carry their IDs.
                snapshot = [
                    self._ensure_record_id(rec, key)
                    for key, rec in self._storage.items()
                ]

        step = cfg.batch_size
        for start in range(0, len(snapshot), step):
            for rec in snapshot[start:start + step]:
                yield rec.copy(deep=True)
            # Small yield between batches to avoid starving other coroutines.
            await asyncio.sleep(0)

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into memory via the mixin's default implementation."""
        return await self._default_stream_write(records, config)

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using pure-Python calculations."""
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )

228 

229 

class SyncMemoryDatabase(  # type: ignore[misc]
    SyncDatabase,
    StreamingMixin,
    ConfigurableBase,
    VectorConfigMixin,        # Parse vector config
    SQLiteVectorSupport,      # Provides _compute_similarity
    PythonVectorSearchMixin,  # Provides python_vector_search_sync
    BulkEmbedMixin,           # Bulk embedding operations
    VectorOperationsMixin     # Standard vector interface
):
    """Synchronous in-memory database implementation.

    Mirrors ``AsyncMemoryDatabase`` but guards storage with a re-entrant
    ``threading.RLock`` instead of an asyncio lock.  Records are deep-copied
    on the way in and out so callers never share mutable state with storage.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        super().__init__(config)
        self._storage: OrderedDict[str, Record] = OrderedDict()
        self._lock = threading.RLock()

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> SyncMemoryDatabase:
        """Create from config dictionary."""
        return cls(config)

    def _generate_id(self) -> str:
        """Generate a unique ID for a record."""
        return str(uuid.uuid4())

    def create(self, record: Record) -> str:
        """Create a new record in memory and return its storage ID."""
        with self._lock:
            # Use the centralized helper (as the async backend does) so
            # copying and ID assignment stay consistent across both
            # implementations.
            record_copy, storage_id = self._prepare_record_for_storage(record)
            self._storage[storage_id] = record_copy
            return storage_id

    def read(self, id: str) -> Record | None:
        """Read a record from memory, or ``None`` if it does not exist."""
        with self._lock:
            record = self._storage.get(id)
            # Centralized helper copies the record and attaches its ID,
            # matching AsyncMemoryDatabase.read (previously the sync path
            # returned a bare deep copy without ID preparation).
            return self._prepare_record_from_storage(record, id)

    def update(self, id: str, record: Record) -> bool:
        """Update a record in memory; return ``False`` when ``id`` is unknown."""
        with self._lock:
            if id in self._storage:
                self._storage[id] = record.copy(deep=True)
                return True
            return False

    def delete(self, id: str) -> bool:
        """Delete a record from memory; return whether anything was removed."""
        with self._lock:
            if id in self._storage:
                del self._storage[id]
                return True
            return False

    def exists(self, id: str) -> bool:
        """Check if a record exists in memory."""
        with self._lock:
            return id in self._storage

    def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record with the specified ID; return the ID."""
        with self._lock:
            self._storage[id] = record.copy(deep=True)
            return id

    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query (simple or complex)."""
        # Handle ComplexQuery using base class implementation
        if isinstance(query, ComplexQuery):
            return self._search_with_complex_query(query)

        with self._lock:
            results = []
            for id, record in self._storage.items():
                # A record matches only if every filter accepts its value.
                matches = True
                for filter in query.filters:
                    field_value = record.get_value(filter.field)
                    if not filter.matches(field_value):
                        matches = False
                        break
                if matches:
                    results.append((id, record))

            # Sorting/pagination/copying are handled by the base-class helper.
            return self._process_search_results(results, query, deep_copy=True)

    def _count_all(self) -> int:
        """Count all records in memory."""
        with self._lock:
            return len(self._storage)

    def clear(self) -> int:
        """Clear all records from memory; return how many were removed."""
        with self._lock:
            count = len(self._storage)
            self._storage.clear()
            return count

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently; return their IDs in order."""
        with self._lock:
            ids = []
            for record in records:
                # Same centralized preparation as single-record create(),
                # consistent with the async backend.
                record_copy, storage_id = self._prepare_record_for_storage(record)
                self._storage[storage_id] = record_copy
                ids.append(storage_id)
            return ids

    def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records efficiently; missing IDs map to ``None``."""
        with self._lock:
            return [
                self._prepare_record_from_storage(self._storage.get(id), id)
                for id in ids
            ]

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently; each slot reports success."""
        with self._lock:
            results = []
            for id in ids:
                if id in self._storage:
                    del self._storage[id]
                    results.append(True)
                else:
                    results.append(False)
            return results

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from memory in batches of ``config.batch_size``."""
        config = config or StreamConfig()

        # Snapshot matching records first so nothing is yielded while the
        # lock is held.
        if query:
            records = self.search(query)
        else:
            with self._lock:
                # Ensure records have IDs when taken directly from storage.
                records = [
                    self._ensure_record_id(record, record_id)
                    for record_id, record in self._storage.items()
                ]

        # Yield deep copies so consumers cannot mutate stored state.
        for i in range(0, len(records), config.batch_size):
            batch = records[i:i + config.batch_size]
            for record in batch:
                yield record.copy(deep=True)

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into memory via the mixin's default implementation."""
        return self._default_stream_write(records, config)

    def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using pure-Python calculations."""
        return self.python_vector_search_sync(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )