Coverage for intelligence_toolkit/detect_entity_networks/index_and_infer.py: 100%
62 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1from collections import defaultdict
2from typing import Any
4import networkx as nx
5import numpy as np
6import polars as pl
7from sklearn.neighbors import NearestNeighbors
9import intelligence_toolkit.detect_entity_networks.config as config
10from intelligence_toolkit.AI.base_embedder import BaseEmbedder
11from intelligence_toolkit.AI.openai_configuration import OpenAIConfiguration
12from intelligence_toolkit.AI.openai_embedder import OpenAIEmbedder
13from intelligence_toolkit.AI.utils import hash_text
14from intelligence_toolkit.detect_entity_networks.config import (
15 ENTITY_LABEL,
16 SIMILARITY_THRESHOLD_MAX,
17 SIMILARITY_THRESHOLD_MIN,
18)
19from intelligence_toolkit.helpers.constants import ATTRIBUTE_VALUE_SEPARATOR
20from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback
async def index_nodes(
    indexed_node_types: list[str],
    main_graph: nx.Graph,
    callbacks: list[ProgressBatchCallback] | None = None,
    functions_embedder: BaseEmbedder | None = None,
    openai_configuration: OpenAIConfiguration | None = None,
    save_cache=True,
):
    """Embed the graph nodes of the selected types and find their nearest neighbours.

    Nodes whose ``"type"`` attribute is in ``indexed_node_types`` are collected,
    sorted, embedded through ``functions_embedder.embed_store_many`` and fed to
    a cosine-metric ``NearestNeighbors`` index.

    Args:
        indexed_node_types: Values of the ``"type"`` node attribute to index.
        main_graph: Graph whose nodes are read with ``nodes(data=True)``; every
            node is expected to carry a ``"type"`` attribute.
        callbacks: Passed through to ``embed_store_many``.
        functions_embedder: Embedder to use. When ``None`` an ``OpenAIEmbedder``
            is built from ``openai_configuration`` and ``config.cache_name``.
        openai_configuration: Configuration for the default embedder. When
            ``None`` an ``OpenAIConfiguration()`` is created.
        save_cache: Passed through to ``embed_store_many`` as its third argument.

    Returns:
        A tuple ``(embedded_texts, nearest_text_distances, nearest_text_indices)``:
        the list of embedded node texts and the two arrays returned by
        ``NearestNeighbors.kneighbors``. Each row holds at most 20 neighbours,
        fewer when fewer than 20 nodes were embedded.

    Raises:
        ValueError: If ``indexed_node_types`` is empty or no embeddings were
            produced for the selected node types.
    """
    if len(indexed_node_types) == 0:
        msg = "No node types to index"
        raise ValueError(msg)

    text_types = sorted(
        (node, attrs["type"])
        for node, attrs in main_graph.nodes(data=True)
        if attrs["type"] in indexed_node_types
    )
    texts = [text for text, _ in text_types]

    data = [
        {
            "hash": hash_text(text),
            "text": text,
            "additional_details": {"type": text_type},
        }
        for text, text_type in text_types
    ]

    if openai_configuration is None:
        openai_configuration = OpenAIConfiguration()
    if functions_embedder is None:
        functions_embedder = OpenAIEmbedder(openai_configuration, config.cache_name)
    data_embeddings = await functions_embedder.embed_store_many(
        data,
        callbacks,
        save_cache,
    )

    # sort data_embeddings by text so the vectors follow the same order as the
    # sorted (text, type) pairs.
    # NOTE(review): this assumes the embedder returns exactly one record per
    # input with its "text" unchanged - confirm against embed_store_many.
    data_embeddings.sort(key=lambda x: x["text"])
    embeddings = [np.array(d["vector"]) for d in data_embeddings]
    if not embeddings:
        # Fail with a clear message instead of an opaque sklearn error.
        msg = "No nodes of the indexed types to embed"
        raise ValueError(msg)

    vals = [(n, t, e) for (n, t), e in zip(text_types, embeddings, strict=False)]
    # Each tuple is one row; say so explicitly rather than relying on polars
    # inferring the orientation from the data shape.
    edf = pl.DataFrame(vals, schema=["text", "type", "vector"], orient="row")

    edf = edf.filter(pl.col("text").is_in(texts))
    embedded_texts = edf["text"].to_list()

    # kneighbors raises when asked for more neighbours than fitted samples,
    # so cap the neighbourhood size for small graphs.
    max_neighbors = 20
    nbrs = NearestNeighbors(
        n_neighbors=min(max_neighbors, len(embeddings)),
        n_jobs=1,
        algorithm="auto",
        leaf_size=20,
        metric="cosine",
    ).fit(embeddings)

    (
        nearest_text_distances,
        nearest_text_indices,
    ) = nbrs.kneighbors(embeddings)

    return embedded_texts, nearest_text_distances, nearest_text_indices
def infer_nodes(
    similarity_threshold: float,
    embedded_texts: list[str],
    nearest_text_indices: list[list[int]],
    nearest_text_distances: list[list[float]],
    progress_callbacks: list[ProgressBatchCallback] | None = None,
) -> defaultdict[Any, set]:
    """Link texts whose neighbour distance is within ``similarity_threshold``.

    Args:
        similarity_threshold: Largest neighbour distance that still produces a
            link; must lie between ``SIMILARITY_THRESHOLD_MIN`` and
            ``SIMILARITY_THRESHOLD_MAX``.
        embedded_texts: Texts, indexed by the positions used in the two
            neighbour tables.
        nearest_text_indices: For each text, the positions of its neighbours.
        nearest_text_distances: For each text, the distances to those neighbours.
        progress_callbacks: Optional callbacks; ``on_batch_change`` is invoked
            on each of them once per processed text.

    Returns:
        A mapping from each linked text to the set of texts it is linked with.
        Every link is stored in both directions.

    Raises:
        ValueError: If ``similarity_threshold`` is outside the allowed range.
    """
    below_minimum = similarity_threshold < SIMILARITY_THRESHOLD_MIN
    if below_minimum or similarity_threshold > SIMILARITY_THRESHOLD_MAX:
        msg = f"Similarity threshold must be between {SIMILARITY_THRESHOLD_MIN} and {SIMILARITY_THRESHOLD_MAX}"
        raise ValueError(msg)

    links = defaultdict(set)
    total = len(embedded_texts)

    for position, source_text in enumerate(embedded_texts):
        for callback in progress_callbacks or ():
            callback.on_batch_change(position, total, "Infering links...")

        neighbours = zip(
            nearest_text_indices[position],
            nearest_text_distances[position],
            strict=False,
        )
        for neighbour_position, distance in neighbours:
            # Skip the text itself and anything farther than the threshold.
            if neighbour_position != position and distance <= similarity_threshold:
                target_text = embedded_texts[neighbour_position]
                # Distinct positions may still hold the same text.
                if source_text != target_text:
                    links[source_text].add(target_text)
                    links[target_text].add(source_text)

    return links
def create_inferred_links(inferred_links: defaultdict[Any, set]) -> list[tuple]:
    """Flatten a link mapping into a list of ``(text, neighbour)`` pairs.

    A pair is emitted only when ``text < neighbour``, so a link stored in both
    directions (as ``infer_nodes`` stores it) appears once in the result.
    """
    pairs = []
    for source, neighbours in inferred_links.items():
        for target in neighbours:
            # Keep only the ascending direction of each link.
            if source < target:
                pairs.append((source, target))
    return pairs
async def index_and_infer(
    indexed_node_types: list[str],
    main_graph: nx.Graph,
    network_similarity_threshold: float,
    callbacks: list[ProgressBatchCallback] | None = None,
    functions_embedder: BaseEmbedder | None = None,
    openai_configuration: OpenAIConfiguration | None = None,
    save_cache=True,
) -> tuple[defaultdict[Any, set], int]:
    """Index the selected graph nodes and infer links between similar ones.

    Runs ``index_nodes`` followed by ``infer_nodes``.

    Args:
        indexed_node_types: Node types to index; forwarded to ``index_nodes``.
        main_graph: Graph holding the nodes; must not be empty.
        network_similarity_threshold: Threshold forwarded to ``infer_nodes``;
            must lie between ``SIMILARITY_THRESHOLD_MIN`` and
            ``SIMILARITY_THRESHOLD_MAX``.
        callbacks: Forwarded to both ``index_nodes`` and ``infer_nodes``.
        functions_embedder: Forwarded to ``index_nodes``.
        openai_configuration: Forwarded to ``index_nodes``.
        save_cache: Forwarded to ``index_nodes``.

    Returns:
        A tuple ``(inferred_links, n_embedded)``: the mapping returned by
        ``infer_nodes`` and the number of embedded texts.

    Raises:
        ValueError: If the graph is empty, the threshold is out of range, or
            one of the delegated steps rejects its input.
    """
    if not len(main_graph.nodes()):
        msg = "Graph is empty"
        raise ValueError(msg)

    # Reject a bad threshold before the embedding step runs; infer_nodes makes
    # the same check, but only after the embeddings have been computed.
    if (
        network_similarity_threshold < SIMILARITY_THRESHOLD_MIN
        or network_similarity_threshold > SIMILARITY_THRESHOLD_MAX
    ):
        msg = f"Similarity threshold must be between {SIMILARITY_THRESHOLD_MIN} and {SIMILARITY_THRESHOLD_MAX}"
        raise ValueError(msg)

    (
        embedded_texts,
        nearest_text_distances,
        nearest_text_indices,
    ) = await index_nodes(
        indexed_node_types,
        main_graph,
        callbacks,
        functions_embedder,
        openai_configuration,
        save_cache,
    )

    inferred_links = infer_nodes(
        network_similarity_threshold,
        embedded_texts,
        nearest_text_indices,
        nearest_text_distances,
        callbacks,
    )

    return inferred_links, len(embedded_texts)