Coverage for src / dataknobs_xization / json / json_chunker.py: 55%

262 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 16:15 -0700

1"""JSON chunker for generating RAG-optimized chunks from JSON data. 

2 

3This module provides functionality to chunk JSON data (objects, arrays, JSONL files) 

4into units suitable for RAG (Retrieval-Augmented Generation) applications, with 

5preserved metadata and configurable text generation. 

6 

7Supports both in-memory and streaming modes for handling large files. 

8""" 

9 

10from __future__ import annotations 

11 

12import json 

13import re 

14from dataclasses import dataclass, field 

15from pathlib import Path 

16from typing import Any, Iterator, Literal 

17 

# Patterns for detecting technical/non-text string values; used by
# JSONChunker._is_technical_value to keep noise out of generated text.
UUID_PATTERN = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.IGNORECASE
)
# Long unbroken base64-looking runs (20+ chars, optional "=" padding).
BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/]{20,}={0,2}$")
# ISO-8601-style timestamps, e.g. "2024-01-02T03:04:05" or "2024-01-02 03:04:05".
TIMESTAMP_PATTERN = re.compile(
    r"^\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}"  # ISO format
)

# Field names commonly containing human-readable text content; these are
# prioritized when auto-generating chunk text (lowercase comparison).
TEXT_FIELD_NAMES = frozenset({
    "title", "name", "description", "content", "text", "summary",
    "body", "message", "comment", "note", "notes", "abstract",
    "overview", "details", "explanation", "definition", "label",
})

# Field names to skip during text generation (technical/metadata fields).
# Their values are still preserved in the chunk's flattened metadata.
SKIP_FIELD_NAMES = frozenset({
    "id", "uuid", "guid", "_id", "created_at", "updated_at",
    "created", "updated", "timestamp", "modified", "hash",
    "checksum", "signature", "token", "key", "secret",
})

40 

41 

@dataclass
class JSONChunkConfig:
    """Configuration for JSON chunking.

    Attributes:
        max_chunk_size: Maximum size of generated text in characters; longer
            text is truncated and given a trailing "..." ellipsis
        text_template: Optional Jinja2 template for text generation (overrides
            auto-flatten; rendering requires jinja2 to be installed)
        text_fields: Specific fields to use for text (None = auto-detect);
            dot-notation paths are resolved against nested objects
        nested_separator: Separator for flattened nested keys (e.g., "config.database.host")
        array_handling: How to handle arrays of primitives - "expand" lists all
            values, "join" shows the first 10 with a total count, "first" uses
            only the first element
        include_field_names: Whether to include field names in generated text
        skip_technical_fields: Whether to skip UUIDs, timestamps, base64 in text generation
    """

    max_chunk_size: int = 1000
    text_template: str | None = None
    text_fields: list[str] | None = None
    nested_separator: str = "."
    array_handling: Literal["expand", "join", "first"] = "expand"
    include_field_names: bool = True
    skip_technical_fields: bool = True

64 

@dataclass
class JSONChunk:
    """A chunk generated from JSON data.

    Attributes:
        text: Generated embeddable text
        metadata: All original fields (flattened for nested objects)
        source_path: JSON path to this chunk's source (e.g., "[0].products[2]")
        source_file: Original file path (if from file)
        embedding_text: Enriched text optimized for embedding
        chunk_index: Index of this chunk in the sequence
    """

    text: str
    metadata: dict[str, Any]
    source_path: str = ""
    source_file: str = ""
    embedding_text: str = ""
    chunk_index: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Convert chunk to dictionary representation."""
        # Mirror the dataclass fields in declaration order; attribute values
        # (including the metadata dict) are shared, not copied.
        field_names = (
            "text",
            "metadata",
            "source_path",
            "source_file",
            "embedding_text",
            "chunk_index",
        )
        return {name: getattr(self, name) for name in field_names}

95 

96 

class JSONChunker:
    """Chunker for generating chunks from JSON data with preserved metadata.

    Supports both in-memory processing and streaming for large files.

    Example:
        >>> chunker = JSONChunker()
        >>> # In-memory processing
        >>> chunks = chunker.chunk({"title": "Hello", "content": "World"})
        >>> # Streaming from file
        >>> for chunk in chunker.stream_chunks("large_data.jsonl"):
        ...     process(chunk)
    """

    def __init__(self, config: JSONChunkConfig | None = None):
        """Initialize the JSON chunker.

        Args:
            config: Configuration for chunking behavior; defaults used if None
        """
        self.config = config or JSONChunkConfig()
        self._chunk_index = 0  # Sequence number assigned to each produced chunk
        self._jinja_env: Any = None  # jinja2.Environment, lazily created on first render

    def chunk(
        self,
        data: dict[str, Any] | list[dict[str, Any]],
        source: str = "",
    ) -> list[JSONChunk]:
        """Chunk in-memory JSON data.

        Args:
            data: JSON object or array of objects to chunk
            source: Optional source identifier (e.g., file path)

        Returns:
            List of JSONChunk objects; non-dict array elements are skipped

        Raises:
            ValueError: If data is neither a dict nor a list
        """
        self._chunk_index = 0

        if isinstance(data, dict):
            return [self._process_item(data, source_path="", source_file=source)]
        elif isinstance(data, list):
            chunks = []
            for idx, item in enumerate(data):
                # Only object elements become chunks; primitives are ignored.
                if isinstance(item, dict):
                    chunks.append(
                        self._process_item(item, source_path=f"[{idx}]", source_file=source)
                    )
            return chunks
        else:
            raise ValueError(f"Expected dict or list, got {type(data).__name__}")

    def stream_chunks(
        self,
        source: str | Path,
        timeout: int = 10,
    ) -> Iterator[JSONChunk]:
        """Stream chunks from large JSON files without loading into memory.

        Leverages dataknobs_utils.json_utils streaming infrastructure.

        Supported formats:
            - JSON arrays: Each top-level element becomes a chunk source
            - JSONL files: Each line is a complete JSON object
            - Compressed files: .gz files auto-detected and decompressed
            - URLs: Remote JSON fetched with streaming

        Args:
            source: File path, URL, or JSON string
            timeout: Request timeout for URLs (seconds)

        Yields:
            JSONChunk objects as they are processed
        """
        source_str = str(source)
        self._chunk_index = 0

        # Detect format (by extension) and process accordingly
        if self._is_jsonl_file(source_str):
            yield from self._stream_jsonl(source_str)
        else:
            yield from self._stream_json_array(source_str, timeout)

    def _is_jsonl_file(self, source: str) -> bool:
        """Check if source is a JSONL/NDJSON file based on extension."""
        # str.endswith accepts a tuple of suffixes - one call instead of an or-chain
        return source.lower().endswith(
            (".jsonl", ".jsonl.gz", ".ndjson", ".ndjson.gz")
        )

    def _stream_jsonl(self, source: str) -> Iterator[JSONChunk]:
        """Stream from a JSONL file (one JSON object per line).

        Blank lines, malformed lines, and non-object lines are skipped.
        """
        import gzip

        source_path = Path(source)

        # Handle gzip transparently. Open directly instead of assigning
        # lambdas (PEP 8 / E731); the `with` block closes either handle.
        if source.lower().endswith(".gz"):
            f = gzip.open(source_path, "rt", encoding="utf-8")
        else:
            f = open(source_path, "r", encoding="utf-8")

        with f:
            for line_num, line in enumerate(f):
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    continue  # Skip malformed lines (best-effort streaming)
                if isinstance(item, dict):
                    yield self._process_item(
                        item,
                        source_path=f"[{line_num}]",
                        source_file=source,
                    )

    def _stream_json_array(self, source: str, timeout: int) -> Iterator[JSONChunk]:
        """Stream from a JSON array file using json_utils infrastructure."""
        try:
            from dataknobs_utils.json_utils import (
                stream_json_data,
                PathSorter,
                ArrayElementAcceptStrategy,
                Path as JsonPath,
                build_jq_path,
            )
        except ImportError:
            # Fall back to loading entire file if streaming utils not available
            yield from self._fallback_load(source)
            return

        # Use PathSorter to group streamed leaf paths into records
        sorter = PathSorter(
            ArrayElementAcceptStrategy(max_array_level=0),
            max_groups=2,
        )

        item_num = 0

        def visitor(item: Any, path: tuple[Any, ...]) -> None:
            # Record each streamed leaf under its jq-style path
            nonlocal item_num
            jq_path = build_jq_path(path, keep_list_idxs=True)
            sorter.add_path(JsonPath(jq_path, item, line_num=item_num))
            item_num += 1

        stream_json_data(source, visitor, timeout=timeout)

        # Process collected groups.
        # NOTE(review): close_group() is invoked once per iteration while
        # looping over sorter.groups - confirm against the PathSorter API
        # that this does not mutate the collection being iterated.
        if sorter.groups:
            for group in sorter.groups:
                sorter.close_group(check_size=False)
                record_dict = group.as_dict()
                # Handle array at root level: a single root key whose value
                # is a list of record objects
                if isinstance(record_dict, dict) and len(record_dict) == 1:
                    root_key = next(iter(record_dict.keys()))
                    items = record_dict[root_key]
                    if isinstance(items, list):
                        for idx, item in enumerate(items):
                            if isinstance(item, dict):
                                yield self._process_item(
                                    item,
                                    source_path=f".{root_key}[{idx}]",
                                    source_file=source,
                                )

    def _fallback_load(self, source: str) -> Iterator[JSONChunk]:
        """Fallback: load entire file when streaming utils unavailable.

        Yields nothing if the file does not exist.
        """
        import gzip

        # Path is already imported at module level; no local re-import needed.
        source_path = Path(source)
        if not source_path.exists():
            return

        if source.lower().endswith(".gz"):
            with gzip.open(source_path, "rt", encoding="utf-8") as f:
                data = json.load(f)
        else:
            with open(source_path, "r", encoding="utf-8") as f:
                data = json.load(f)

        yield from self.chunk(data, source=source)

    def _process_item(
        self,
        item: dict[str, Any],
        source_path: str,
        source_file: str,
    ) -> JSONChunk:
        """Process a single JSON object into a chunk.

        Args:
            item: JSON object to process
            source_path: JSON path to this item
            source_file: Source file path

        Returns:
            JSONChunk with generated text and preserved metadata
        """
        # Flatten nested structure for metadata
        flat_metadata = self._flatten(item)

        # Generate text: explicit template wins over auto-flattening
        if self.config.text_template:
            text = self._render_template(item)
        else:
            text = self._auto_generate_text(item)

        # Truncate if needed. max(..., 0) guards against max_chunk_size < 3,
        # where a negative slice bound would silently drop text from the end.
        if len(text) > self.config.max_chunk_size:
            text = text[: max(self.config.max_chunk_size - 3, 0)] + "..."

        # Generate embedding text (enriched with type/tag context)
        embedding_text = self._build_embedding_text(item, text)

        chunk = JSONChunk(
            text=text,
            metadata=flat_metadata,
            source_path=source_path,
            source_file=source_file,
            embedding_text=embedding_text,
            chunk_index=self._chunk_index,
        )
        self._chunk_index += 1
        return chunk

    def _flatten(
        self,
        obj: dict[str, Any],
        prefix: str = "",
    ) -> dict[str, Any]:
        """Flatten nested dict/list structure using the configured separator.

        Args:
            obj: Object to flatten
            prefix: Current key prefix

        Returns:
            Flattened dictionary
        """
        result: dict[str, Any] = {}
        sep = self.config.nested_separator

        for key, value in obj.items():
            full_key = f"{prefix}{sep}{key}" if prefix else key

            if isinstance(value, dict):
                result.update(self._flatten(value, full_key))
            elif isinstance(value, list):
                if value and isinstance(value[0], dict):
                    # List of objects: record the count and flatten only the
                    # first element as a representative sample. The count key
                    # now uses the configured separator (was hard-coded ".",
                    # inconsistent with non-default separators).
                    result[f"{full_key}{sep}_count"] = len(value)
                    result.update(self._flatten(value[0], f"{full_key}[0]"))
                else:
                    # Empty list or list of primitives - store as-is
                    result[full_key] = value
            else:
                result[full_key] = value

        return result

    def _auto_generate_text(self, item: dict[str, Any]) -> str:
        """Auto-generate embeddable text from JSON object.

        Algorithm:
            1. Extract title/name/label field as primary identifier
            2. Concatenate text-like fields (description, content, summary)
            3. Format remaining non-technical fields with field names
            4. Handle arrays based on config

        Args:
            item: JSON object to convert to text

        Returns:
            Generated text string (newline-joined parts)
        """
        parts: list[str] = []

        # Use explicitly configured fields if given (dot-notation supported)
        if self.config.text_fields:
            for field_name in self.config.text_fields:
                value = self._get_nested_value(item, field_name)
                if value is not None:
                    parts.append(self._format_value(field_name, value))
            return "\n".join(parts)

        # Auto-detect: prioritize known text fields
        used_keys: set[str] = set()

        # First pass: extract primary identifier (first match wins)
        for key in ["title", "name", "label"]:
            if key in item and isinstance(item[key], str):
                parts.append(item[key])
                used_keys.add(key)
                break

        # Second pass: extract text content fields
        for key in item:
            if key in used_keys:
                continue
            lower_key = key.lower()
            if lower_key in TEXT_FIELD_NAMES:
                value = item[key]
                if isinstance(value, str) and value.strip():
                    if not self._is_technical_value(value):
                        # Body-like fields stand alone; others get a label
                        if self.config.include_field_names and key not in ("content", "text", "body"):
                            parts.append(f"{key}: {value}")
                        else:
                            parts.append(value)
                        used_keys.add(key)

        # Third pass: include other non-technical fields
        for key, value in item.items():
            if key in used_keys:
                continue
            lower_key = key.lower()
            if lower_key in SKIP_FIELD_NAMES:
                continue
            if key.startswith("_"):
                continue  # Skip private/internal fields

            formatted = self._format_value(key, value)
            if formatted:
                parts.append(formatted)

        return "\n".join(parts)

    def _format_value(self, key: str, value: Any, depth: int = 0) -> str:
        """Format a value for text generation.

        Args:
            key: Field name
            value: Field value
            depth: Nesting depth (passed through on recursion)

        Returns:
            Formatted string; "" when the value should be omitted
        """
        if value is None:
            return ""

        if isinstance(value, str):
            if self.config.skip_technical_fields and self._is_technical_value(value):
                return ""
            if self.config.include_field_names:
                return f"{key}: {value}"
            return value

        # bool must be tested before int/float: bool is a subclass of int
        if isinstance(value, bool):
            if self.config.include_field_names:
                return f"{key}: {'yes' if value else 'no'}"
            return "yes" if value else "no"

        if isinstance(value, (int, float)):
            if self.config.include_field_names:
                return f"{key}: {value}"
            return str(value)

        if isinstance(value, list):
            if not value:
                return ""
            if isinstance(value[0], dict):
                # List of objects - summarize with a count only
                return f"{key}: {len(value)} items"
            # List of primitives - behavior depends on array_handling
            if self.config.array_handling == "join":
                joined = ", ".join(str(v) for v in value[:10])
                if len(value) > 10:
                    joined += f"... ({len(value)} total)"
                if self.config.include_field_names:
                    return f"{key}: {joined}"
                return joined
            elif self.config.array_handling == "first":
                return self._format_value(key, value[0], depth)
            # "expand" - return all items
            items = [str(v) for v in value]
            if self.config.include_field_names:
                return f"{key}: {', '.join(items)}"
            return ", ".join(items)

        if isinstance(value, dict):
            # Nested object - format each entry recursively, join with ";"
            sub_parts = []
            for k, v in value.items():
                formatted = self._format_value(k, v, depth + 1)
                if formatted:
                    sub_parts.append(formatted)
            if sub_parts:
                if self.config.include_field_names:
                    return f"{key}: {'; '.join(sub_parts)}"
                return "; ".join(sub_parts)
            return ""

        # Unknown/unsupported types are omitted
        return ""

    def _is_technical_value(self, value: str) -> bool:
        """Check if a string value appears to be technical/non-text.

        Returns False when skip_technical_fields is disabled or the string
        is too short (< 10 chars) to match any technical pattern.
        """
        if not self.config.skip_technical_fields:
            return False

        if len(value) < 10:
            return False

        if UUID_PATTERN.match(value):
            return True
        if BASE64_PATTERN.match(value) and len(value) > 50:
            return True
        if TIMESTAMP_PATTERN.match(value):
            return True

        return False

    def _get_nested_value(self, obj: dict[str, Any], path: str) -> Any:
        """Get a value from a nested dict using dot-notation path.

        Args:
            obj: Object to traverse
            path: Separator-delimited path (e.g., "config.database.host")

        Returns:
            Value at path, or None if not found
        """
        parts = path.split(self.config.nested_separator)
        current: Any = obj

        for part in parts:
            if isinstance(current, dict) and part in current:
                current = current[part]
            else:
                return None

        return current

    def _render_template(self, item: dict[str, Any]) -> str:
        """Render text using Jinja2 template.

        Args:
            item: JSON object whose keys become template variables

        Returns:
            Rendered text string

        Raises:
            ImportError: If jinja2 is not installed
        """
        if self._jinja_env is None:
            try:
                from jinja2 import Environment
            except ImportError as exc:
                # Chain the original error for easier debugging
                raise ImportError(
                    "jinja2 is required for template-based text generation. "
                    "Install it with: pip install jinja2"
                ) from exc
            self._jinja_env = Environment()

        template = self._jinja_env.from_string(self.config.text_template)
        return template.render(**item)

    def _build_embedding_text(self, item: dict[str, Any], base_text: str) -> str:
        """Build enriched text optimized for embedding.

        Adds type/category and tag context that improves semantic search.

        Args:
            item: Original JSON object
            base_text: Generated base text

        Returns:
            Enriched text for embedding (space-joined)
        """
        parts = []

        # Add type/category context if available (first match wins)
        for key in ["type", "category", "kind", "class"]:
            if key in item and isinstance(item[key], str):
                parts.append(f"[{item[key].upper()}]")
                break

        parts.append(base_text)

        # Add up to 5 string tags/keywords if available (first match wins)
        for key in ["tags", "keywords", "labels"]:
            if key in item and isinstance(item[key], list):
                tags = [str(t) for t in item[key][:5] if isinstance(t, str)]
                if tags:
                    parts.append(f"Tags: {', '.join(tags)}")
                break

        return " ".join(parts)