Coverage for intelligence_toolkit/AI/vector_store.py: 95%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# # Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# # Licensed under the MIT license. See LICENSE file in the project. 

3# # 

4from typing import Any 

5 

6import duckdb 

7import lancedb 

8import pyarrow as pa 

9from pandas import DataFrame 

10 

11from intelligence_toolkit.helpers.constants import CACHE_PATH 

12 

# Message for the ValueError raised by VectorStore methods that need a table
# when none has been opened.
table_missing_msg = "Table not initialized"

14 

15 

class VectorStore:
    """Wrapper around a LanceDB table with DuckDB-backed column lookups.

    On construction a LanceDB connection is opened at ``path``. When a table
    name is supplied, the table is created through that connection (with
    ``exist_ok=True``) and a Lance dataset view of it is kept in
    ``duckdb_data`` for SQL queries issued through DuckDB.
    """

    # Handle to the LanceDB table; stays None when no table name was given.
    table = None
    # Lance dataset view of the table that the DuckDB queries scan.
    duckdb_data = None
    # Name of the wrapped table, remembered so drop_table can refer to it.
    table_name: str | None = None

    def __init__(
        self,
        table_name: str | None = None,
        path: str = CACHE_PATH,
        schema: pa.Schema | None = None,
    ):
        """Connect to the database and optionally open a table.

        Args:
            table_name: Table to create/open. When None, only the connection
                is established and every table-level method raises
                ValueError.
            path: Location handed to ``lancedb.connect``.
            schema: Arrow schema forwarded to ``create_table``.
        """
        self.db_connection = lancedb.connect(path)
        # Remember the name: drop_table needs it, and the original code read
        # this attribute without ever assigning it.
        self.table_name = table_name
        if table_name is not None:
            self.table = self.db_connection.create_table(
                table_name, schema=schema, exist_ok=True
            )
            self.duckdb_data = self.table.to_lance()

    def save(self, items: list[Any]) -> None:
        """Append ``items`` to the table.

        NOTE(review): ``duckdb_data`` is not refreshed here; callers
        presumably need ``update_duckdb_data`` before ``search_by_column``
        sees the new rows -- confirm against LanceDB versioning semantics.

        Raises:
            ValueError: If no table has been opened.
        """
        if self.table is None:
            raise ValueError(table_missing_msg)
        self.table.add(items)

    def search_by_column(self, texts: list[str] | str, column: str) -> DataFrame:
        """Return the distinct rows whose ``column`` equals any of ``texts``.

        Args:
            texts: A single value or a list of values to match.
            column: Name of the column to filter on.

        Returns:
            A DataFrame with the matching rows; it has no rows when ``texts``
            is empty.

        Raises:
            ValueError: If no table has been opened.
        """
        if self.table is None:
            raise ValueError(table_missing_msg)
        values = [texts] if isinstance(texts, str) else list(texts)

        # DuckDB resolves the table name ``arrow_data`` in the SQL below by
        # looking up this local variable, so the name must stay in sync with
        # the query text and the query must be executed from this frame.
        arrow_data = self.duckdb_data  # noqa: F841

        if not values:
            # ``IN ()`` is not valid SQL; an always-false filter still yields
            # an empty frame that carries the table's columns.
            return duckdb.execute(
                "SELECT DISTINCT * FROM arrow_data WHERE FALSE"
            ).df()

        # Quote the identifier and bind the values as parameters instead of
        # interpolating them, so quotes inside either cannot alter the query.
        quoted_column = '"' + column.replace('"', '""') + '"'
        placeholders = ", ".join("?" for _ in values)
        query = (
            f"SELECT DISTINCT * FROM arrow_data "
            f"WHERE {quoted_column} IN ({placeholders})"
        )
        return duckdb.execute(query, values).df()

    def search_by_vector(self, vector: list[float], k: int = 10) -> list[dict]:
        """Run a vector search on the table and return at most ``k`` hits.

        Raises:
            ValueError: If no table has been opened.
        """
        if self.table is None:
            raise ValueError(table_missing_msg)

        return self.table.search(vector).limit(k).to_list()

    def update_duckdb_data(self) -> None:
        """Re-read the Lance dataset view that ``search_by_column`` queries.

        Raises:
            ValueError: If no table has been opened.
        """
        if self.table is None:
            raise ValueError(table_missing_msg)
        self.duckdb_data = self.table.to_lance()

    def drop_table(self) -> None:
        """Drop the wrapped table from the database.

        Afterwards the instance holds no table, so table-level methods raise
        ValueError instead of operating on stale handles.

        Raises:
            ValueError: If no table has been opened.
        """
        if self.table is None:
            raise ValueError(table_missing_msg)
        self.db_connection.drop_table(self.table_name)
        self.table = None
        self.duckdb_data = None

    def drop_db(self) -> None:
        """Drop the whole database behind the connection."""
        self.db_connection.drop_database()