Coverage for src / kemi / chunker.py: 94%

135 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

"""Semantic text chunking: split long memories into meaning-preserving units.

Uses embedding-based semantic segmentation to identify topic shifts and break
content at natural boundaries (sentence and paragraph breaks) rather than at
arbitrary character offsets. The aim is better retrieval than naive
fixed-size splitting, because each chunk carries a coherent piece of meaning.

Pipeline (see :func:`semantic_chunks` for the parameters of each step):
1. Split text into sentences using punctuation + capitalization heuristics.
2. Group consecutive sentences into token-bounded candidate chunks.
3. Compute embedding similarity between adjacent sentence pairs.
4. Insert breaks where similarity drops below ``similarity_threshold``
   (topic shift detected).
5. Merge chunks with fewer than ``min_sentences_per_chunk`` sentences into
   the preceding chunk.
6. Apply ``overlap_sentences`` overlap between adjacent chunks for context
   continuity.
"""

16 

17from __future__ import annotations 

18 

19import re 

20from dataclasses import dataclass 

21from typing import TYPE_CHECKING 

22 

23if TYPE_CHECKING: 

24 from kemi.adapters.base import EmbeddingAdapter 

25 

# ---------------------------------------------------------------------------
# Public data types
# ---------------------------------------------------------------------------

# Metadata key reserved for chunk information. It is exported but not used
# within this module; presumably callers store ``ChunkInfo.to_dict()`` under
# this key in a memory's metadata — TODO confirm against the storage layer.
CHUNK_META_KEY = "_chunk_info"

31 

32 

@dataclass
class ChunkInfo:
    """Positional and boundary metadata for one chunk of a split memory.

    Fields:
        chunk_index: Zero-based position of this chunk in the chunk sequence.
        total_chunks: Number of chunks the original memory was split into.
        parent_memory_id: Identifier of the memory this chunk belongs to, or
            ``None`` when the chunk is standalone.
        overlap_with_prev: Number of sentences shared with the previous chunk.
        overlap_with_next: Number of sentences shared with the next chunk.
        boundary_strength: How strong the break preceding this chunk was,
            in the range 0.0–1.0.
    """

    chunk_index: int
    total_chunks: int
    parent_memory_id: str | None
    overlap_with_prev: int
    overlap_with_next: int
    boundary_strength: float

    def to_dict(self) -> dict:
        """Return every field as a plain dict, keys in declaration order."""
        field_names = (
            "chunk_index",
            "total_chunks",
            "parent_memory_id",
            "overlap_with_prev",
            "overlap_with_next",
            "boundary_strength",
        )
        return {name: getattr(self, name) for name in field_names}

53 

54 

@dataclass
class Chunk:
    """One semantically coherent piece of a larger memory.

    Attributes:
        content: The chunk text.
        chunk_info: Position/boundary metadata, or ``None`` if not attached.
        embedding: Embedding vector for ``content``, or ``None`` if it has
            not been computed.
    """

    content: str
    chunk_info: ChunkInfo | None = None
    embedding: list[float] | None = None

    def __len__(self) -> int:
        """Length of the chunk text in characters."""
        text = self.content
        return len(text)

    def word_count(self) -> int:
        """Number of whitespace-separated words in the chunk text."""
        words = self.content.split()
        return len(words)

    def token_count_estimate(self) -> int:
        """Rough token estimate: word count times 1.3, truncated to an int.

        1.3 tokens per word is a common rule of thumb for English text.
        """
        estimate = 1.3 * self.word_count()
        return int(estimate)

72 

73 

74# --------------------------------------------------------------------------- 

75# Sentence splitting 

76# ---------------------------------------------------------------------------# 

77 

78# Regex-based sentence boundary detection (no external NLP library needed). 

79# Matches punctuation followed by whitespace and a capital letter (new sentence start). 

80# The end-of-string branch $ is handled separately by the remainder fallback below. 

81_SENTENCE_END_PATTERN = re.compile(r"(?<=[.!?])\s+(?=[A-Z])", re.VERBOSE) 

82 

83# Internal abbreviations that commonly appear mid-sentence and shouldn't break. 

84_ABBREVIATIONS = frozenset({ 

85 "mr.", "mrs.", "ms.", "dr.", "prof.", "sr.", "jr.", "vs.", "etc.", 

86 "e.g.", "i.e.", "fig.", "vol.", "no.", "p.", 

87}) 

88 

89 

90def _is_sentence_boundary(prev_sentence: str, next_sentence: str) -> bool: 

91 """Return True if the boundary between two sentences is strong. 

92 

93 A boundary is strong when the next sentence starts with a capital letter 

94 AND the previous sentence does not contain an abbreviation at either 

95 the end (e.g. "ended with e.g.") or the beginning (e.g. "Dr. Smith ..."). 

96 """ 

97 if not prev_sentence or not next_sentence: 

98 return False 

99 prev_lower = prev_sentence.rstrip().lower() 

100 next_cap = next_sentence.lstrip()[0].isupper() if next_sentence else False 

101 ends_with_abbrev = any(prev_lower.endswith(abbr) for abbr in _ABBREVIATIONS) 

102 first_word = prev_lower.split()[0] if prev_lower.split() else "" 

103 starts_with_abbrev = first_word in _ABBREVIATIONS 

104 abbrev = ends_with_abbrev or starts_with_abbrev 

105 return next_cap and not abbrev 

106 

107 

108def split_into_sentences(text: str) -> list[str]: 

109 """Split text into sentences using punctuation + capitalization heuristics. 

110 

111 Handles common edge cases (abbreviations, decimal numbers, etc.) via post-processing. 

112 Returns list of sentence strings, empty list for empty/blank input. 

113 """ 

114 if not text or not text.strip(): 

115 return [] 

116 

117 # Step 1: rough split on sentence-ending punctuation followed by whitespace+capital 

118 raw_sentences: list[str] = [] 

119 start = 0 

120 for match in _SENTENCE_END_PATTERN.finditer(text): 

121 end = match.end() 

122 sent = text[start:end].strip() 

123 if sent: 

124 raw_sentences.append(sent) 

125 start = end 

126 

127 # Catch any remaining text after last sentence-ending punctuation 

128 remainder = text[start:].strip() 

129 if remainder: 

130 raw_sentences.append(remainder) 

131 

132 if not raw_sentences: 

133 # Fallback: split on double newlines / paragraph breaks 

134 paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] 

135 if paragraphs: 

136 return paragraphs 

137 # Last resort: treat the whole thing as one sentence 

138 return [text.strip()] 

139 

140 # Step 2: attach true fragments to the previous sentence, but keep 

141 # complete sentences (those ending in . ! ?) intact, even when short. 

142 # "First sentence." is a complete 2-word sentence; merging it into the 

143 # next one would lose sentence boundaries. Only orphan fragments 

144 # (e.g. "I think" with no terminator) get attached. 

145 merged: list[str] = [] 

146 for sent in raw_sentences: 

147 if not sent: 

148 continue 

149 words = sent.split() 

150 sent_stripped = sent.rstrip().lower() 

151 ends_with_terminator = any( 

152 sent_stripped.endswith(p) for p in (".", "!", "?") 

153 ) 

154 is_fragment = len(words) < 3 and not ends_with_terminator 

155 if is_fragment and merged: 

156 merged[-1] = merged[-1] + " " + sent 

157 else: 

158 merged.append(sent) 

159 

160 return merged 

161 

162 

163# --------------------------------------------------------------------------- 

164# Semantic chunking core 

165# ---------------------------------------------------------------------------# 

166 

# Public API of this module. ``chunk_and_embed`` is the public convenience
# wrapper around ``semantic_chunks`` and is exported alongside it.
__all__ = [
    "Chunk",
    "ChunkInfo",
    "split_into_sentences",
    "semantic_chunks",
    "chunk_and_embed",
    "CHUNK_META_KEY",
]

168 

169 

def semantic_chunks(
    text: str,
    embed: EmbeddingAdapter,
    *,
    max_tokens: int = 256,
    overlap_sentences: int = 1,
    min_sentences_per_chunk: int = 1,
    similarity_threshold: float = 0.5,
) -> list[Chunk]:
    """Split *text* into semantically coherent chunks for embedding.

    Algorithm (embedding-based semantic segmentation):
      1. Split text into sentences (:func:`split_into_sentences`).
      2. Embed every sentence and compute the cosine similarity of each
         adjacent pair, normalized from [-1, 1] to [0, 1].
      3. Group consecutive sentences greedily. A new group starts when the
         next sentence would exceed *max_tokens*, or when the pair
         similarity drops below *similarity_threshold* (topic shift).
      4. Merge groups with fewer than *min_sentences_per_chunk* sentences
         into the preceding group.
      5. Repeat up to *overlap_sentences* trailing sentences of each group at
         the start of the next chunk, for context continuity.
      6. Embed the final chunk texts (overlap included).

    Args:
        text: The input text to chunk.
        embed: EmbeddingAdapter whose ``embed(list[str])`` is expected to
            return one vector per input string, in input order.
        max_tokens: Target max token count per chunk, estimated as
            ``int(words * 1.3)`` per sentence (default 256 ≈ ~200 words).
            A chunk may exceed it when a single sentence does, when small
            groups are merged, or by the overlap sentences prepended to it.
        overlap_sentences: Number of trailing sentences of a chunk repeated
            at the start of the next chunk (default 1; negative values are
            treated as 0). A chunk never hands over all of its sentences.
        min_sentences_per_chunk: Groups with fewer sentences than this are
            merged into the preceding group (default 1, i.e. no merging).
        similarity_threshold: Normalized similarity (0.0–1.0) below which the
            boundary between two adjacent sentences counts as a topic shift
            and forces a break. The default 0.5 corresponds to a raw cosine
            below 0. Higher values produce more breaks; 0.0 disables
            similarity-driven breaks.

    Returns:
        List of Chunk objects, each with content, ChunkInfo metadata and an
        embedding. Returns an empty list if text is empty/whitespace.
    """
    if not text or not text.strip():
        return []

    sentences = split_into_sentences(text)
    if not sentences:
        return []

    # Steps 1-2: adjacent-sentence similarity (no embedding call for a
    # single sentence).
    pair_sims = _adjacent_similarities(sentences, embed)

    # Steps 3-4: token- and topic-bounded groups as half-open index ranges.
    starts = _group_starts(
        sentences, pair_sims, max_tokens, similarity_threshold, min_sentences_per_chunk
    )
    ends = starts[1:] + [len(sentences)]

    # Step 5: handover[k] = number of sentences of group k repeated at the
    # start of chunk k + 1. Capped at len(group) - 1 so that no chunk is
    # entirely contained in its successor.
    overlap = max(0, overlap_sentences)
    handover = [min(overlap, end - start - 1) for start, end in zip(starts, ends)]

    total = len(starts)
    chunks: list[Chunk] = []
    for idx, (start, end) in enumerate(zip(starts, ends)):
        carried = handover[idx - 1] if idx > 0 else 0
        # The first chunk has no preceding boundary. Later chunks report the
        # dissimilarity of the sentence pair straddling their leading break.
        strength = 1.0 if idx == 0 else 1.0 - pair_sims[start - 1]
        info = ChunkInfo(
            chunk_index=idx,
            total_chunks=total,
            parent_memory_id=None,
            overlap_with_prev=carried,
            overlap_with_next=handover[idx] if idx < total - 1 else 0,
            boundary_strength=strength,
        )
        content = " ".join(sentences[start - carried:end])
        chunks.append(Chunk(content=content, chunk_info=info))

    # Step 6: embed the final chunk texts in one batch, so each embedding
    # matches the content actually stored (including overlap). This also
    # covers single-sentence input, which previously got no embedding.
    embeddings = embed.embed([chunk.content for chunk in chunks])
    for chunk, embedding in zip(chunks, embeddings):
        chunk.embedding = embedding

    return chunks


def _sentence_tokens(sentence: str) -> int:
    """Estimate the token count of one sentence (words * 1.3, truncated)."""
    return int(len(sentence.split()) * 1.3)


def _adjacent_similarities(sentences: list[str], embed: EmbeddingAdapter) -> list[float]:
    """Return the normalized (0..1) cosine similarity of each adjacent sentence pair.

    Element ``i`` compares ``sentences[i]`` with ``sentences[i + 1]``. No
    embedding call is made when there are fewer than two sentences.
    """
    if len(sentences) < 2:
        return []
    vectors = embed.embed(sentences)
    return [
        (_cosine_sim(vectors[i], vectors[i + 1]) + 1.0) / 2.0
        for i in range(len(sentences) - 1)
    ]


def _group_starts(
    sentences: list[str],
    pair_sims: list[float],
    max_tokens: int,
    similarity_threshold: float,
    min_sentences_per_chunk: int,
) -> list[int]:
    """Return the sentence index at which each chunk group starts (always begins with 0).

    A new group starts when adding the next sentence would exceed
    *max_tokens*, or when its similarity to the previous sentence is below
    *similarity_threshold*. Groups (other than the first) with fewer than
    *min_sentences_per_chunk* sentences are then folded into their
    predecessor.
    """
    starts = [0]
    running_tokens = 0
    for i, sentence in enumerate(sentences):
        sent_tokens = _sentence_tokens(sentence)
        if i > 0 and (
            running_tokens + sent_tokens > max_tokens
            or pair_sims[i - 1] < similarity_threshold
        ):
            starts.append(i)
            running_tokens = 0
        running_tokens += sent_tokens

    # Fold undersized groups into the previous one. Sizes are measured
    # before folding, matching a single left-to-right merge pass.
    bounds = starts + [len(sentences)]
    return [
        start
        for k, start in enumerate(starts)
        if k == 0 or bounds[k + 1] - start >= min_sentences_per_chunk
    ]

315 

316 

317def _cosine_sim(a: list[float], b: list[float]) -> float: 

318 """Compute cosine similarity between two vectors (no numpy dependency).""" 

319 dot = sum(x * y for x, y in zip(a, b)) 

320 norm_a = sum(x * x for x in a) ** 0.5 

321 norm_b = sum(y * y for y in b) ** 0.5 

322 if norm_a == 0.0 or norm_b == 0.0: 

323 return 0.0 

324 return dot / (norm_a * norm_b) 

325 

326 

327# --------------------------------------------------------------------------- 

328# Convenience: chunk and embed a single text 

329# ---------------------------------------------------------------------------# 

330 

331 

def chunk_and_embed(
    text: str,
    embed: EmbeddingAdapter,
    *,
    max_tokens: int = 256,
    overlap_sentences: int = 1,
    similarity_threshold: float = 0.5,
    min_sentences_per_chunk: int = 1,
) -> list[Chunk]:
    """Split text into chunks and make sure each one carries an embedding.

    Convenience wrapper around :func:`semantic_chunks`. Any chunk returned
    with ``embedding is None`` is embedded here in one batched
    ``embed.embed`` call, so every returned Chunk has its ``embedding`` set.

    Args:
        text: The input text to chunk.
        embed: EmbeddingAdapter forwarded to :func:`semantic_chunks` and used
            to fill in any missing embeddings.
        max_tokens: Forwarded to :func:`semantic_chunks`.
        overlap_sentences: Forwarded to :func:`semantic_chunks`.
        similarity_threshold: Forwarded to :func:`semantic_chunks`.
        min_sentences_per_chunk: Forwarded to :func:`semantic_chunks`
            (default 1, the same as its default).

    Returns:
        The chunks, each with ``embedding`` set. Empty list for
        empty/whitespace input.
    """
    chunks = semantic_chunks(
        text,
        embed,
        max_tokens=max_tokens,
        overlap_sentences=overlap_sentences,
        min_sentences_per_chunk=min_sentences_per_chunk,
        similarity_threshold=similarity_threshold,
    )
    # Defensive: fill in any embedding semantic_chunks left unset, so callers
    # can rely on the contract stated above.
    missing = [chunk for chunk in chunks if chunk.embedding is None]
    if missing:
        vectors = embed.embed([chunk.content for chunk in missing])
        for chunk, vector in zip(missing, vectors):
            chunk.embedding = vector
    return chunks