Coverage for src / kemi / topics.py: 94%
53 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""Topic clustering for memories using local CPU-only methods.
3Requires scikit-learn (optional dependency).
4"""
6import logging
7from typing import Any
9from kemi.models import LifecycleState, MemoryObject
11logger = logging.getLogger(__name__)
14def _sklearn_available() -> bool:
15 try:
16 import sklearn # noqa: F401
18 return True
19 except ImportError:
20 return False
def cluster_memories(
    store: Any,
    user_id: str,
    n_clusters: int = 3,
    namespace: str = "default",
) -> dict[str, list[MemoryObject]]:
    """Cluster a user's memories into topic groups using KMeans on embeddings.

    Only memories in the ACTIVE or DECAYING lifecycle state that carry an
    embedding take part in clustering.

    Args:
        store: StorageAdapter instance; must provide
            ``get_all_by_user(user_id, lifecycle_filter=..., namespace=...)``.
        user_id: User ID.
        n_clusters: Requested number of clusters. Capped to the number of
            memories with embeddings, and raised to 2 when smaller than 2.
        namespace: Memory namespace.

    Returns:
        Dict mapping a topic label to the list of MemoryObjects in that
        cluster, ordered from the largest cluster to the smallest. Labels
        come from ``_generate_topic_label``; when two clusters produce the
        same label, later ones get a numeric suffix (``"Label (2)"``) so no
        cluster is overwritten. If fewer than two memories have embeddings,
        or KMeans fails, all of them are returned under the single key
        ``"topic_0"``; if none have embeddings the dict is empty.

    Raises:
        ImportError: If scikit-learn is not installed.
    """
    if not _sklearn_available():
        raise ImportError(
            "scikit-learn is required for topic clustering. Install with: pip install scikit-learn"
        )

    # Imported lazily because scikit-learn is an optional dependency.
    from sklearn.cluster import KMeans

    memories = store.get_all_by_user(
        user_id,
        lifecycle_filter=[LifecycleState.ACTIVE, LifecycleState.DECAYING],
        namespace=namespace,
    )

    # Memories without an embedding cannot be clustered.
    memories_with_emb = [m for m in memories if m.embedding is not None]

    # Fewer than two points: nothing to cluster, return them as one group.
    if len(memories_with_emb) < 2:
        return {"topic_0": memories_with_emb} if memories_with_emb else {}

    # Never ask for more clusters than points, and never fewer than two.
    effective_k = max(2, min(n_clusters, len(memories_with_emb)))

    embeddings = [m.embedding for m in memories_with_emb]

    try:
        kmeans = KMeans(n_clusters=effective_k, random_state=42, n_init="auto")
        labels = kmeans.fit_predict(embeddings)
    except Exception as e:
        # Deliberate best-effort: degrade to a single group instead of failing.
        logger.warning("KMeans clustering failed: %s", e)
        return {"topic_0": memories_with_emb}

    # Group memories by the cluster id KMeans assigned to them.
    clusters: dict[int, list[MemoryObject]] = {}
    for mem, cluster_id in zip(memories_with_emb, labels, strict=False):
        clusters.setdefault(int(cluster_id), []).append(mem)

    # Largest cluster first; the sort is stable, so equally sized clusters
    # keep the order in which they were first seen.
    ordered_groups = sorted(clusters.values(), key=len, reverse=True)

    labeled: dict[str, list[MemoryObject]] = {}
    for idx, mems in enumerate(ordered_groups):
        base_label = _generate_topic_label(mems, idx)
        label = base_label
        suffix = 2
        # Two clusters can share the same top keywords; disambiguate rather
        # than silently overwriting the earlier cluster's memories.
        while label in labeled:
            label = f"{base_label} ({suffix})"
            suffix += 1
        labeled[label] = mems

    return labeled
91def _generate_topic_label(memories: list[MemoryObject], index: int) -> str:
92 """Generate a human-readable topic label from memory contents."""
93 # Simple TF-like keyword extraction
94 word_freq: dict[str, int] = {}
95 stopwords = {
96 "the",
97 "a",
98 "an",
99 "is",
100 "are",
101 "was",
102 "were",
103 "be",
104 "been",
105 "being",
106 "have",
107 "has",
108 "had",
109 "do",
110 "does",
111 "did",
112 "will",
113 "would",
114 "could",
115 "should",
116 "may",
117 "might",
118 "must",
119 "shall",
120 "can",
121 "need",
122 "dare",
123 "ought",
124 "used",
125 "to",
126 "of",
127 "in",
128 "for",
129 "on",
130 "with",
131 "at",
132 "by",
133 "from",
134 "as",
135 "into",
136 "through",
137 "during",
138 "before",
139 "after",
140 "above",
141 "below",
142 "between",
143 "under",
144 "and",
145 "but",
146 "or",
147 "yet",
148 "so",
149 "if",
150 "because",
151 "although",
152 "though",
153 "while",
154 "where",
155 "i",
156 "you",
157 "he",
158 "she",
159 "it",
160 "we",
161 "they",
162 "me",
163 "him",
164 "her",
165 "us",
166 "them",
167 "my",
168 "your",
169 "his",
170 "its",
171 "our",
172 "their",
173 "mine",
174 "yours",
175 "hers",
176 "ours",
177 "theirs",
178 "this",
179 "that",
180 "these",
181 "those",
182 "am",
183 }
185 for mem in memories:
186 for word in mem.content.lower().split():
187 clean = word.strip(".,!?;:'\"()-")
188 if len(clean) > 3 and clean not in stopwords:
189 word_freq[clean] = word_freq.get(clean, 0) + 1
191 top_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)[:2]
192 if top_words:
193 label = " ".join(w[0].capitalize() for w in top_words)
194 else:
195 label = f"Topic {index + 1}"
197 return label