Coverage for src / kemi / topics.py: 94%

53 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""Topic clustering for memories using local CPU-only methods. 

2 

3Requires scikit-learn (optional dependency). 

4""" 

5 

6import logging 

7from typing import Any 

8 

9from kemi.models import LifecycleState, MemoryObject 

10 

11logger = logging.getLogger(__name__) 

12 

13 

def _sklearn_available() -> bool:
    """Report whether scikit-learn can be imported in this environment.

    Returns:
        True if ``import sklearn`` succeeds, False if it raises ImportError.
    """
    try:
        import sklearn  # noqa: F401
    except ImportError:
        return False
    return True

21 

22 

def cluster_memories(
    store: Any,
    user_id: str,
    n_clusters: int = 3,
    namespace: str = "default",
) -> dict[str, list[MemoryObject]]:
    """Cluster a user's memories into topic groups using KMeans on embeddings.

    Only ACTIVE and DECAYING memories that carry an embedding are clustered.

    Args:
        store: StorageAdapter instance; must provide ``get_all_by_user``.
        user_id: User ID.
        n_clusters: Requested number of clusters. Capped to the number of
            memories with embeddings and raised to a minimum of 2.
        namespace: Memory namespace.

    Returns:
        Dict mapping a topic label to the MemoryObjects in that cluster,
        inserted largest cluster first. Labels come from
        ``_generate_topic_label``; if two clusters yield the same label, the
        later ones get a numeric suffix (e.g. ``"Label (2)"``) so that no
        cluster is overwritten. With fewer than two usable memories, or when
        KMeans fails, the result is ``{}`` or a single ``"topic_0"`` group.

    Raises:
        ImportError: If scikit-learn is not installed.
    """
    if not _sklearn_available():
        raise ImportError(
            "scikit-learn is required for topic clustering. Install with: pip install scikit-learn"
        )

    from sklearn.cluster import KMeans

    memories = store.get_all_by_user(
        user_id,
        lifecycle_filter=[LifecycleState.ACTIVE, LifecycleState.DECAYING],
        namespace=namespace,
    )

    # Memories without an embedding cannot be clustered.
    memories_with_emb = [m for m in memories if m.embedding is not None]

    # Fewer than two points: nothing to cluster.
    if len(memories_with_emb) < 2:
        return {"topic_0": memories_with_emb} if memories_with_emb else {}

    # Cap at the number of samples, but never go below two clusters.
    effective_k = max(2, min(n_clusters, len(memories_with_emb)))

    embeddings = [m.embedding for m in memories_with_emb]

    try:
        kmeans = KMeans(n_clusters=effective_k, random_state=42, n_init="auto")
        labels = kmeans.fit_predict(embeddings)
    except Exception as e:
        # Deliberate best-effort: fall back to one group rather than failing.
        logger.warning("KMeans clustering failed: %s", e)
        return {"topic_0": memories_with_emb}

    clusters: dict[str, list[MemoryObject]] = {}
    for mem, label in zip(memories_with_emb, labels, strict=False):
        clusters.setdefault(f"topic_{label}", []).append(mem)

    # Largest cluster first; the stable sort keeps first-seen order on ties.
    ordered_groups = sorted(clusters.values(), key=len, reverse=True)

    labeled: dict[str, list[MemoryObject]] = {}
    for idx, mems in enumerate(ordered_groups):
        base_label = _generate_topic_label(mems, idx)
        label = base_label
        suffix = 2
        # Two clusters can share the same top keywords; without this the
        # later cluster would replace the earlier one and drop its memories.
        while label in labeled:
            label = f"{base_label} ({suffix})"
            suffix += 1
        labeled[label] = mems

    return labeled

89 

90 

# Characters trimmed from both ends of each whitespace-separated token.
_TOPIC_LABEL_PUNCTUATION = ".,!?;:'\"()-"

# Words that never count as topic keywords.
_TOPIC_LABEL_STOPWORDS = frozenset(
    {
        # Articles, auxiliaries and modals
        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
        "have", "has", "had", "do", "does", "did", "will", "would", "could",
        "should", "may", "might", "must", "shall", "can", "need", "dare",
        "ought", "used",
        # Prepositions
        "to", "of", "in", "for", "on", "with", "at", "by", "from", "as",
        "into", "through", "during", "before", "after", "above", "below",
        "between", "under",
        # Conjunctions
        "and", "but", "or", "yet", "so", "if", "because", "although",
        "though", "while", "where",
        # Pronouns and determiners
        "i", "you", "he", "she", "it", "we", "they", "me", "him", "her",
        "us", "them", "my", "your", "his", "its", "our", "their", "mine",
        "yours", "hers", "ours", "theirs", "this", "that", "these", "those",
        "am",
    }
)


def _generate_topic_label(memories: list[MemoryObject], index: int) -> str:
    """Generate a human-readable topic label from memory contents.

    Each memory's ``content`` is lower-cased and split on whitespace, and
    surrounding punctuation is trimmed from every token. Tokens longer than
    three characters that are not stopwords are counted, and the two most
    frequent ones (ties resolved by first occurrence) are capitalized and
    joined with a space.

    Args:
        memories: Memories in the cluster; each must expose a string
            ``content`` attribute.
        index: Zero-based cluster position, used only for the fallback label.

    Returns:
        The keyword label, or ``"Topic <index + 1>"`` when no token qualifies.
    """
    from collections import Counter

    word_freq: Counter[str] = Counter()
    for mem in memories:
        for raw_word in mem.content.lower().split():
            token = raw_word.strip(_TOPIC_LABEL_PUNCTUATION)
            if len(token) > 3 and token not in _TOPIC_LABEL_STOPWORDS:
                word_freq[token] += 1

    # most_common() keeps first-encountered order among equal counts.
    top_words = word_freq.most_common(2)
    if not top_words:
        return f"Topic {index + 1}"
    return " ".join(word.capitalize() for word, _ in top_words)