Coverage for intelligence_toolkit/detect_entity_networks/index_and_infer.py: 100%

62 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1from collections import defaultdict 

2from typing import Any 

3 

4import networkx as nx 

5import numpy as np 

6import polars as pl 

7from sklearn.neighbors import NearestNeighbors 

8 

9import intelligence_toolkit.detect_entity_networks.config as config 

10from intelligence_toolkit.AI.base_embedder import BaseEmbedder 

11from intelligence_toolkit.AI.openai_configuration import OpenAIConfiguration 

12from intelligence_toolkit.AI.openai_embedder import OpenAIEmbedder 

13from intelligence_toolkit.AI.utils import hash_text 

14from intelligence_toolkit.detect_entity_networks.config import ( 

15 ENTITY_LABEL, 

16 SIMILARITY_THRESHOLD_MAX, 

17 SIMILARITY_THRESHOLD_MIN, 

18) 

19from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR 

20from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback 

21 

22 

async def index_nodes(
    indexed_node_types: list[str],
    main_graph: nx.Graph,
    callbacks: list[ProgressBatchCallback] | None = None,
    functions_embedder: BaseEmbedder | None = None,
    openai_configuration: OpenAIConfiguration | None = None,
    save_cache=True,
):
    """Embed graph nodes of the selected types and find each node's nearest neighbours.

    Nodes of ``main_graph`` whose ``"type"`` attribute is listed in
    ``indexed_node_types`` are embedded through ``functions_embedder`` and
    indexed with a cosine-metric ``NearestNeighbors`` model.

    Args:
        indexed_node_types: Node types to embed; must not be empty.
        main_graph: Graph whose nodes carry a ``"type"`` data attribute.
        callbacks: Progress callbacks forwarded to the embedder.
        functions_embedder: Embedder to use. When ``None`` an
            ``OpenAIEmbedder`` is built from ``openai_configuration`` and
            ``config.cache_name``.
        openai_configuration: Configuration for the default embedder. When
            ``None`` a default ``OpenAIConfiguration()`` is created.
        save_cache: Forwarded to ``embed_store_many``.

    Returns:
        A tuple ``(embedded_texts, nearest_text_distances,
        nearest_text_indices)``: the embedded node texts in sorted order and
        the two results of ``NearestNeighbors.kneighbors`` for the embeddings.

    Raises:
        ValueError: If ``indexed_node_types`` is empty, or if no embeddings
            were produced (e.g. no node matches the requested types).
    """
    if not indexed_node_types:
        msg = "No node types to index"
        raise ValueError(msg)

    # Sorted (node, type) pairs give a deterministic order to align against.
    text_types = sorted(
        (node, attrs["type"])
        for node, attrs in main_graph.nodes(data=True)
        if attrs["type"] in indexed_node_types
    )
    texts = [node for node, _ in text_types]

    payload = [
        {
            "hash": hash_text(text),
            "text": text,
            "additional_details": {"type": text_type},
        }
        for text, text_type in text_types
    ]

    if openai_configuration is None:
        openai_configuration = OpenAIConfiguration()
    if functions_embedder is None:
        functions_embedder = OpenAIEmbedder(openai_configuration, config.cache_name)
    data_embeddings = await functions_embedder.embed_store_many(
        payload,
        callbacks,
        save_cache,
    )

    # Sort by text so the embeddings line up positionally with text_types.
    # NOTE(review): assumes the embedder returns one entry per input and
    # preserves the "text" key - confirm against BaseEmbedder.
    data_embeddings.sort(key=lambda item: item["text"])
    embeddings = [np.array(item["vector"]) for item in data_embeddings]

    if not embeddings:
        # Fail with a clear message instead of an opaque error from fit().
        msg = "No embeddings were produced for the requested node types"
        raise ValueError(msg)

    vals = [(n, t, e) for (n, t), e in zip(text_types, embeddings, strict=False)]
    # Explicit orientation: each tuple in vals is one row, never a column.
    edf = pl.DataFrame(vals, schema=["text", "type", "vector"], orient="row")

    edf = edf.filter(pl.col("text").is_in(texts))
    embedded_texts = edf["text"].to_list()

    # kneighbors() rejects n_neighbors larger than the number of fitted
    # samples, so clamp the neighbourhood size for small graphs.
    neighbor_count = min(20, len(embeddings))
    nbrs = NearestNeighbors(
        n_neighbors=neighbor_count,
        n_jobs=1,
        algorithm="auto",
        leaf_size=20,
        metric="cosine",
    ).fit(embeddings)

    (
        nearest_text_distances,
        nearest_text_indices,
    ) = nbrs.kneighbors(embeddings)

    return embedded_texts, nearest_text_distances, nearest_text_indices

84 

85 

def infer_nodes(
    similarity_threshold: float,
    embedded_texts: list[str],
    nearest_text_indices: list[list[int]],
    nearest_text_distances: list[list[float]],
    progress_callbacks: list[ProgressBatchCallback] | None = None,
) -> defaultdict[Any, set]:
    """Build symmetric links between texts whose neighbour distance is within the threshold.

    Args:
        similarity_threshold: Largest neighbour distance that still produces
            a link; must lie within
            ``[SIMILARITY_THRESHOLD_MIN, SIMILARITY_THRESHOLD_MAX]``.
        embedded_texts: Texts, indexed consistently with the neighbour tables.
        nearest_text_indices: For each text, positions of its neighbours in
            ``embedded_texts``.
        nearest_text_distances: For each text, the distances matching
            ``nearest_text_indices``.
        progress_callbacks: Optional callbacks notified once per text.

    Returns:
        Mapping from each linked text to the set of texts linked to it.

    Raises:
        ValueError: If ``similarity_threshold`` is outside the allowed range.
    """
    out_of_range = (
        similarity_threshold > SIMILARITY_THRESHOLD_MAX
        or similarity_threshold < SIMILARITY_THRESHOLD_MIN
    )
    if out_of_range:
        msg = f"Similarity threshold must be between {SIMILARITY_THRESHOLD_MIN} and {SIMILARITY_THRESHOLD_MAX}"
        raise ValueError(msg)

    links = defaultdict(set)
    total = len(embedded_texts)

    for position, text in enumerate(embedded_texts):
        for callback in progress_callbacks or ():
            callback.on_batch_change(position, total, "Infering links...")

        neighbours = zip(
            nearest_text_indices[position],
            nearest_text_distances[position],
            strict=False,
        )
        for other_position, distance in neighbours:
            # A text is never linked to itself.
            if other_position == position:
                continue
            # Too far away to count as a match.
            if not distance <= similarity_threshold:
                continue
            other_text = embedded_texts[other_position]
            # Distinct positions holding the same text are not linked either.
            if text == other_text:
                continue
            # Record the link in both directions.
            links[text].add(other_text)
            links[other_text].add(text)

    return links

117 

118 

def create_inferred_links(inferred_links: defaultdict[Any, set]) -> list[tuple]:
    """Flatten a symmetric link mapping into a list of unique ``(a, b)`` pairs.

    Only pairs with ``a < b`` are kept, so a link stored in both directions
    is emitted once.
    """
    pairs: list[tuple] = []
    for source, neighbours in inferred_links.items():
        for target in neighbours:
            if source < target:
                pairs.append((source, target))
    return pairs

123 

124 

async def index_and_infer(
    indexed_node_types: list[str],
    main_graph: nx.Graph,
    network_similarity_threshold: float,
    callbacks: list[ProgressBatchCallback] | None = None,
    functions_embedder: BaseEmbedder | None = None,
    openai_configuration: OpenAIConfiguration | None = None,
    save_cache=True,
) -> tuple[defaultdict[Any, set], int]:
    """Embed the selected graph nodes and infer links between similar ones.

    Runs ``index_nodes`` and then ``infer_nodes`` on its output.

    Args:
        indexed_node_types: Node types to embed.
        main_graph: Graph to index; must contain at least one node.
        network_similarity_threshold: Threshold passed to ``infer_nodes``.
        callbacks: Progress callbacks, passed to both the indexing and the
            inference step.
        functions_embedder: Optional embedder forwarded to ``index_nodes``.
        openai_configuration: Optional configuration forwarded to
            ``index_nodes``.
        save_cache: Forwarded to ``index_nodes``.

    Returns:
        A tuple ``(inferred_links, embedded_count)``: the mapping produced by
        ``infer_nodes`` and the number of texts that were embedded.

    Raises:
        ValueError: If ``main_graph`` has no nodes. Errors raised by
            ``index_nodes`` and ``infer_nodes`` propagate unchanged.
    """
    if len(main_graph.nodes()) == 0:
        msg = "Graph is empty"
        raise ValueError(msg)

    (
        embedded_texts,
        nearest_text_distances,
        nearest_text_indices,
    ) = await index_nodes(
        indexed_node_types,
        main_graph,
        callbacks,
        functions_embedder,
        openai_configuration,
        save_cache,
    )

    # infer_nodes takes indices before distances, the reverse of the order
    # in which index_nodes returns them.
    inferred_links = infer_nodes(
        network_similarity_threshold,
        embedded_texts,
        nearest_text_indices,
        nearest_text_distances,
        callbacks,
    )

    return inferred_links, len(embedded_texts)