Coverage for session_buddy / tools / knowledge_graph_tools.py: 37.86%

268 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1#!/usr/bin/env python3 

2"""Knowledge Graph MCP tools for semantic memory management. 

3 

4This module provides MCP tools for interacting with the DuckPGQ-based knowledge graph, 

5enabling semantic memory through entity-relationship modeling. 

6 

7Refactored to use utility modules for reduced code duplication. 

8""" 

9 

10from __future__ import annotations 

11 

12import re 

13from typing import TYPE_CHECKING, Any 

14 

15from session_buddy.utils.error_handlers import _get_logger 

16from session_buddy.utils.messages import ToolMessages 

17 

18if TYPE_CHECKING: 

19 from collections.abc import Awaitable, Callable 

20 

21 from session_buddy.adapters.knowledge_graph_adapter import ( 

22 KnowledgeGraphDatabaseAdapter as KnowledgeGraphDatabase, 

23 ) 

24 

25 

26# ============================================================================ 

27# Service Resolution 

28# ============================================================================ 

29 

30 

def _check_knowledge_graph_available() -> bool:
    """Report whether the ``duckdb`` package can be imported.

    Returns:
        ``True`` when ``import duckdb`` succeeds, ``False`` when it raises
        ``ImportError``.
    """
    try:
        import duckdb  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    else:
        return True

39 

40 

async def _require_knowledge_graph() -> KnowledgeGraphDatabase:
    """Build and initialize a knowledge graph adapter.

    The adapter and DI modules are imported lazily inside the function so that
    an import failure is reported the same way as any other setup failure.

    Returns:
        The adapter instance after ``initialize()`` has been awaited.

    Raises:
        RuntimeError: If importing, configuring, constructing or initializing
            the adapter raises any ``Exception``; the original error is
            chained as the cause.
    """
    try:
        from session_buddy.adapters.knowledge_graph_adapter import (
            KnowledgeGraphDatabaseAdapter,
        )
        from session_buddy.di import configure

        configure()
        adapter = KnowledgeGraphDatabaseAdapter()
        await adapter.initialize()
        return adapter
    except Exception as exc:
        raise RuntimeError(f"Knowledge graph not available: {exc}") from exc

56 

57 

async def _execute_kg_operation(
    operation_name: str, operation: Callable[[Any], Awaitable[str]]
) -> str:
    """Run ``operation`` against a knowledge graph and always return a message.

    Args:
        operation_name: Human-readable label used in the log entry and in the
            failure message.
        operation: Async callable that receives the opened knowledge graph and
            returns the text to hand back to the caller.

    Returns:
        The operation's result on success. If the knowledge graph cannot be
        obtained, the error text followed by an install hint. For any other
        failure, the message built by ``ToolMessages.operation_failed``.
    """
    # Acquire the graph in its own ``try`` so that only acquisition failures
    # (which ``_require_knowledge_graph`` reports as RuntimeError) produce the
    # "install dependencies" hint. A RuntimeError raised later by the
    # operation itself is a genuine operation failure and must be logged.
    try:
        kg_context = await _require_knowledge_graph()
    except RuntimeError as e:
        return f"{e!s}. Install dependencies: uv sync"

    try:
        async with kg_context as kg:
            return await operation(kg)
    except Exception as e:
        _get_logger().exception(f"Error in {operation_name}: {e}")
        return ToolMessages.operation_failed(operation_name, e)

70 

71 

72# ============================================================================ 

73# Entity Extraction Patterns 

74# ============================================================================ 

75 

76 

# Maps an entity type label to the regular expression used to find entities of
# that type in free text. Every pattern has exactly one capturing group, so
# ``re.findall`` yields the matched text directly.
# NOTE: ``_extract_patterns_from_context`` applies these with ``re.IGNORECASE``,
# so the ``[A-Z]`` in the "project" pattern also matches lowercase letters.
ENTITY_PATTERNS: dict[str, str] = {
    "project": r"\b([A-Z][a-z]+-[a-z]+(?:-[a-z]+)*)\b",  # kebab-case projects
    "library": r"\b(ACB|FastMCP|DuckDB|pytest|pydantic|uvicorn)\b",
    "technology": r"\b(Python|JavaScript|TypeScript|Docker|Kubernetes)\b",
    "concept": r"\b(dependency injection|semantic memory|property graph|vector search)\b",
}

83 

84 

85# ============================================================================ 

86# Entity Operations 

87# ============================================================================ 

88 

89 

async def _create_entity_operation(
    kg: Any,
    name: str,
    entity_type: str,
    observations: list[str],
    properties: dict[str, Any],
) -> str:
    """Create one entity via ``kg`` and describe the result.

    Args:
        kg: Knowledge graph object exposing an awaitable ``create_entity``.
        name: Name of the new entity.
        entity_type: Type label of the new entity.
        observations: Observations to attach; may be empty.
        properties: Extra properties to attach; may be empty.

    Returns:
        A newline-joined summary containing the name, type and the ``id`` of
        the created entity, plus observation/property lines when non-empty.
    """
    created = await kg.create_entity(
        name=name,
        entity_type=entity_type,
        observations=observations,
        properties=properties,
    )

    report = [
        f"✅ Entity '{name}' created successfully!",
        f"📊 Type: {entity_type}",
        f"🆔 ID: {created['id']}",
    ]

    # Optional detail lines are only shown when there is something to show.
    if observations:
        report.append(f"📝 Observations: {len(observations)}")
    if properties:
        property_names = ", ".join(properties.keys())
        report.append(f"⚙️ Properties: {property_names}")

    _get_logger().info(
        "Entity created",
        entity_name=name,
        entity_type=entity_type,
        observations_count=len(observations),
    )
    return "\n".join(report)

123 

124 

async def _create_entity_impl(
    name: str,
    entity_type: str,
    observations: list[str] | None = None,
    properties: dict[str, Any] | None = None,
) -> str:
    """Create an entity, with shared knowledge graph error handling.

    ``None`` for ``observations`` or ``properties`` is replaced by an empty
    list / dict before the operation runs.

    Returns:
        The message produced by ``_execute_kg_operation``.
    """
    return await _execute_kg_operation(
        "Create entity",
        # The lambda returns the coroutine; the executor awaits it.
        lambda kg: _create_entity_operation(
            kg, name, entity_type, observations or [], properties or {}
        ),
    )

142 

143 

async def _add_observation_operation(
    kg: Any, entity_name: str, observation: str
) -> str:
    """Attach ``observation`` to the entity called ``entity_name``.

    Returns:
        A two-line confirmation when ``kg.add_observation`` returns a truthy
        value, otherwise a "not found" message.
    """
    added = await kg.add_observation(entity_name, observation)

    if added:
        _get_logger().info(
            "Observation added",
            entity_name=entity_name,
            # Only the first 100 characters go to the log.
            observation=observation[:100],
        )
        confirmation = (
            f"✅ Observation added to '{entity_name}'",
            f"📝 Observation: {observation}",
        )
        return "\n".join(confirmation)

    return f"❌ Entity '{entity_name}' not found"

164 

165 

async def _add_observation_impl(entity_name: str, observation: str) -> str:
    """Add an observation to an entity, with shared error handling.

    Returns:
        The message produced by ``_execute_kg_operation``.
    """
    return await _execute_kg_operation(
        "Add observation",
        # The lambda returns the coroutine; the executor awaits it.
        lambda kg: _add_observation_operation(kg, entity_name, observation),
    )

176 

177 

178# ============================================================================ 

179# Relationship Operations 

180# ============================================================================ 

181 

182 

async def _create_relation_operation(
    kg: Any,
    from_entity: str,
    to_entity: str,
    relation_type: str,
    properties: dict[str, Any],
) -> str:
    """Create a directed relationship via ``kg`` and describe the result.

    Args:
        kg: Knowledge graph object exposing an awaitable ``create_relation``.
        from_entity: Name of the source entity.
        to_entity: Name of the target entity.
        relation_type: Label of the relationship.
        properties: Extra properties to attach; may be empty.

    Returns:
        A newline-joined summary including the relation ``id``, or an error
        message when ``kg.create_relation`` returns a falsy value.
    """
    created = await kg.create_relation(
        from_entity=from_entity,
        to_entity=to_entity,
        relation_type=relation_type,
        properties=properties,
    )

    # A falsy result is reported as a missing endpoint.
    if not created:
        return f"❌ One or both entities not found: {from_entity}, {to_entity}"

    report = [
        f"✅ Relationship created: {from_entity} --[{relation_type}]--> {to_entity}",
        f"🆔 Relation ID: {created['id']}",
    ]
    if properties:
        property_names = ", ".join(properties.keys())
        report.append(f"⚙️ Properties: {property_names}")

    _get_logger().info(
        "Relation created",
        from_entity=from_entity,
        to_entity=to_entity,
        relation_type=relation_type,
    )
    return "\n".join(report)

216 

217 

async def _create_relation_impl(
    from_entity: str,
    to_entity: str,
    relation_type: str,
    properties: dict[str, Any] | None = None,
) -> str:
    """Create a relationship, with shared knowledge graph error handling.

    ``None`` for ``properties`` is replaced by an empty dict before the
    operation runs.

    Returns:
        The message produced by ``_execute_kg_operation``.
    """
    return await _execute_kg_operation(
        "Create relation",
        # The lambda returns the coroutine; the executor awaits it.
        lambda kg: _create_relation_operation(
            kg, from_entity, to_entity, relation_type, properties or {}
        ),
    )

235 

236 

237# ============================================================================ 

238# Search Operations 

239# ============================================================================ 

240 

241 

def _format_entity_result(entity: dict[str, Any]) -> list[str]:
    """Format one entity search hit as display lines.

    Args:
        entity: Mapping with ``name`` and ``entity_type`` keys and an optional
            ``observations`` sequence of strings.

    Returns:
        A header line, then (only when observations are present) a count line
        and a preview of the first observation truncated to 80 characters,
        followed by a blank separator line.
    """
    lines = [f"📌 {entity['name']} ({entity['entity_type']})"]

    observations = entity.get("observations")
    # A single truthiness check covers both the count and the preview; the
    # previous nested re-check of the same value could never be false.
    if observations:
        preview = observations[0]
        suffix = "..." if len(preview) > 80 else ""
        lines.append(f" 📝 Observations: {len(observations)}")
        lines.append(f" └─ {preview[:80]}{suffix}")

    lines.append("")
    return lines

255 

256 

def _format_batch_results(
    created: list[str],
    failed: list[tuple[str, str]],
) -> list[str]:
    """Render the outcome of a batch entity creation as display lines.

    Args:
        created: Names of the entities that were created.
        failed: ``(name, error)`` pairs for the entities that were not.

    Returns:
        A title, the success count and up to 10 created names, then (when
        there are failures) a blank line, the failure count and up to 5
        failures. Each list ends with an "and N more" line when truncated.
    """
    created_cap, failed_cap = 10, 5

    report = [
        "📦 Batch Entity Creation Results",
        "",
        f"Successfully Created: {len(created)}",
    ]
    report.extend(f"{name}" for name in created[:created_cap])
    hidden_created = len(created) - created_cap
    if hidden_created > 0:
        report.append(f" • and {hidden_created} more")

    # Nothing more to print when every entity succeeded.
    if not failed:
        return report

    report += ["", f"Failed: {len(failed)}"]
    report.extend(f"{name}: {error}" for name, error in failed[:failed_cap])
    hidden_failed = len(failed) - failed_cap
    if hidden_failed > 0:
        report.append(f" • and {hidden_failed} more")

    return report

287 

288 

async def _search_entities_operation(
    kg: Any, query: str, entity_type: str | None, limit: int
) -> str:
    """Search entities through ``kg`` and render the hits.

    Args:
        kg: Knowledge graph object exposing an awaitable ``search_entities``.
        query: Search text.
        entity_type: Optional type filter passed through to ``kg``.
        limit: Maximum number of results passed through to ``kg``.

    Returns:
        A "no entities found" message when the result is empty, otherwise a
        header followed by the formatted lines of each hit.
    """
    hits = await kg.search_entities(
        query=query,
        entity_type=entity_type,
        limit=limit,
    )

    if not hits:
        return f"🔍 No entities found matching '{query}'"

    output = [f"🔍 Found {len(hits)} entities matching '{query}':", ""]
    for hit in hits:
        output += _format_entity_result(hit)

    _get_logger().info(
        "Entities searched",
        query=query,
        entity_type=entity_type,
        results_count=len(hits),
    )
    return "\n".join(output)

314 

315 

async def _search_entities_impl(
    query: str,
    entity_type: str | None = None,
    limit: int = 10,
) -> str:
    """Search entities, with shared knowledge graph error handling.

    Returns:
        The message produced by ``_execute_kg_operation``.
    """
    return await _execute_kg_operation(
        "Search entities",
        # The lambda returns the coroutine; the executor awaits it.
        lambda kg: _search_entities_operation(kg, query, entity_type, limit),
    )

330 

331 

def _format_relationship(rel: dict[str, Any], direction: str, entity_name: str) -> str:
    """Render one relationship as an arrow diagram line.

    Args:
        rel: Mapping with ``from_entity``, ``to_entity`` and ``relation_type``.
        direction: ``"outgoing"`` always renders left-to-right. ``"both"``
            renders left-to-right only when ``entity_name`` is the source.
            Anything else renders as incoming.
        entity_name: The entity the relationships were looked up for.

    Returns:
        ``source --[type]--> target`` for the outgoing form, or
        ``target <--[type]-- source`` for the incoming form. In both forms
        the arrow points from the relation's source to its target.
    """
    source = rel["from_entity"]
    target = rel["to_entity"]
    label = rel["relation_type"]

    if direction == "outgoing" or (direction == "both" and source == entity_name):
        return f" {source} --[{label}]--> {target}"

    # Incoming form: list the target first so the arrow still points at it.
    # The previous rendering ("source <--[type]-- target") reversed the
    # meaning of the relation.
    return f" {target} <--[{label}]-- {source}"

341 

342 

async def _get_entity_relationships_operation(
    kg: Any, entity_name: str, relation_type: str | None, direction: str
) -> str:
    """Fetch the relationships of ``entity_name`` and render them.

    Args:
        kg: Knowledge graph object exposing an awaitable ``get_relationships``.
        entity_name: Entity whose relationships are requested.
        relation_type: Optional relationship type filter passed to ``kg``.
        direction: Direction filter passed to ``kg`` and to the formatter.

    Returns:
        A "no relationships found" message when the result is empty, otherwise
        a header followed by one formatted line per relationship.
    """
    found = await kg.get_relationships(
        entity_name=entity_name,
        relation_type=relation_type,
        direction=direction,
    )

    if not found:
        return f"🔍 No relationships found for '{entity_name}'"

    output = [f"🔗 Found {len(found)} relationships for '{entity_name}':", ""]
    output.extend(
        _format_relationship(item, direction, entity_name) for item in found
    )

    _get_logger().info(
        "Relationships retrieved",
        entity_name=entity_name,
        relation_type=relation_type,
        direction=direction,
        count=len(found),
    )
    return "\n".join(output)

369 

370 

async def _get_entity_relationships_impl(
    entity_name: str,
    relation_type: str | None = None,
    direction: str = "both",
) -> str:
    """List an entity's relationships, with shared error handling.

    Returns:
        The message produced by ``_execute_kg_operation``.
    """
    return await _execute_kg_operation(
        "Get entity relationships",
        # The lambda returns the coroutine; the executor awaits it.
        lambda kg: _get_entity_relationships_operation(
            kg, entity_name, relation_type, direction
        ),
    )

387 

388 

389# ============================================================================ 

390# Path Finding 

391# ============================================================================ 

392 

393 

async def _find_path_operation(
    kg: Any, from_entity: str, to_entity: str, max_depth: int
) -> str:
    """Look up paths between two entities through ``kg`` and render them.

    Args:
        kg: Knowledge graph object exposing an awaitable ``find_path``.
        from_entity: Name of the start entity.
        to_entity: Name of the end entity.
        max_depth: Depth limit passed through to ``kg``.

    Returns:
        A "no path found" message when the result is empty, otherwise a header
        and, per path, its numbered hop count and endpoints.
    """
    found = await kg.find_path(
        from_entity=from_entity,
        to_entity=to_entity,
        max_depth=max_depth,
    )

    if not found:
        return f"🔍 No path found between '{from_entity}' and '{to_entity}'"

    output = [
        f"🛤️ Found {len(found)} path(s) from '{from_entity}' to '{to_entity}':",
        "",
    ]
    # Numbering starts at 1; each path takes two lines plus a blank separator.
    for number, route in enumerate(found, 1):
        output.append(f"{number}. Path length: {route['path_length']} hop(s)")
        output.append(f" {route['from_entity']} ➜ ... ➜ {route['to_entity']}")
        output.append("")

    _get_logger().info(
        "Paths found",
        from_entity=from_entity,
        to_entity=to_entity,
        paths_count=len(found),
    )
    return "\n".join(output)

428 

429 

async def _find_path_impl(
    from_entity: str,
    to_entity: str,
    max_depth: int = 5,
) -> str:
    """Find paths between two entities, with shared error handling.

    Returns:
        The message produced by ``_execute_kg_operation``.
    """
    return await _execute_kg_operation(
        "Find path",
        # The lambda returns the coroutine; the executor awaits it.
        lambda kg: _find_path_operation(kg, from_entity, to_entity, max_depth),
    )

444 

445 

446# ============================================================================ 

447# Statistics 

448# ============================================================================ 

449 

450 

def _format_entity_types(entity_types: dict[str, int]) -> list[str]:
    """Render per-type entity counts for the statistics report.

    Returns:
        An empty list when ``entity_types`` is empty; otherwise a heading, one
        ``type: count`` line per entry and a trailing blank line.
    """
    if not entity_types:
        return []

    return [
        "📊 Entity Types:",
        *(f"{etype}: {count}" for etype, count in entity_types.items()),
        "",
    ]

460 

461 

def _format_relationship_types(relationship_types: dict[str, int]) -> list[str]:
    """Render per-type relationship counts for the statistics report.

    Returns:
        An empty list when ``relationship_types`` is empty; otherwise a
        heading, one ``type: count`` line per entry and a trailing blank line.
    """
    if not relationship_types:
        return []

    return [
        "🔗 Relationship Types:",
        *(f"{rtype}: {count}" for rtype, count in relationship_types.items()),
        "",
    ]

473 

474 

async def _get_knowledge_graph_stats_operation(kg: Any) -> str:
    """Fetch statistics from ``kg`` and render them as a text report.

    The statistics mapping must contain ``total_entities``,
    ``total_relationships``, ``database_path`` and ``duckpgq_installed``;
    ``entity_types`` and ``relationship_types`` are optional.

    Returns:
        The newline-joined report.
    """
    stats = await kg.get_stats()

    report = [
        "📊 Knowledge Graph Statistics",
        "",
        f"📌 Total Entities: {stats['total_entities']}",
        f"🔗 Total Relationships: {stats['total_relationships']}",
        "",
    ]

    # Per-type breakdowns contribute nothing when their mapping is empty.
    report += _format_entity_types(stats.get("entity_types", {}))
    report += _format_relationship_types(stats.get("relationship_types", {}))

    database_line = f"💾 Database: {stats['database_path']}"
    duckpgq_state = "✅ Installed" if stats["duckpgq_installed"] else "❌ Not installed"
    report.append(database_line)
    report.append(f"🔧 DuckPGQ: {duckpgq_state}")

    # Every statistic is forwarded to the logger as a keyword field.
    _get_logger().info("Knowledge graph stats retrieved", **stats)
    return "\n".join(report)

504 

505 

async def _get_knowledge_graph_stats_impl() -> str:
    """Produce the statistics report, with shared error handling.

    Returns:
        The message produced by ``_execute_kg_operation``.
    """
    stats_operation = _get_knowledge_graph_stats_operation
    return await _execute_kg_operation("Get KG stats", stats_operation)

511 

512 

513# ============================================================================ 

514# Entity Extraction 

515# ============================================================================ 

516 

517 

def _extract_patterns_from_context(context: str) -> dict[str, set[str]]:
    """Find entity mentions in ``context`` using ``ENTITY_PATTERNS``.

    Patterns are applied case-insensitively. Mentions that differ only in
    letter case (for example ``Python`` and ``python``) are treated as the
    same entity, and the spelling of the first occurrence is kept.

    Args:
        context: Free text to scan.

    Returns:
        A mapping from entity type to the set of distinct mentions found.
        Entity types without any match are omitted.
    """
    extracted: dict[str, set[str]] = {}
    for entity_type, pattern in ENTITY_PATTERNS.items():
        # Key on the case-folded text so case variants collapse into one
        # entry; ``setdefault`` keeps the first spelling seen.
        first_spelling: dict[str, str] = {}
        for match in re.findall(pattern, context, re.IGNORECASE):
            first_spelling.setdefault(match.casefold(), match)
        if first_spelling:
            extracted[entity_type] = set(first_spelling.values())
    return extracted

526 

527 

async def _auto_create_entity_if_new(
    kg: Any, entity_name: str, entity_type: str
) -> bool:
    """Create ``entity_name`` unless ``kg`` already knows it.

    Returns:
        ``True`` when a new entity was created, ``False`` when
        ``kg.find_entity_by_name`` returned a truthy value.
    """
    if await kg.find_entity_by_name(entity_name):
        return False

    await kg.create_entity(
        name=entity_name,
        entity_type=entity_type,
        observations=["Extracted from conversation context"],
    )
    return True

541 

542 

async def _process_entity_type(
    kg: Any,
    entity_type: str,
    entities: set[str],
    auto_create: bool,
) -> tuple[list[str], int, int]:
    """List the extracted entities of one type, optionally creating them.

    Args:
        kg: Knowledge graph object used when ``auto_create`` is set.
        entity_type: Type label shared by all ``entities``.
        entities: Entity names extracted for this type.
        auto_create: When true, each name is created if it does not exist yet.

    Returns:
        ``(lines, count, created)``: the display lines (heading, sorted names,
        blank line), the number of names listed, and the number of entities
        that were newly created.
    """
    ordered_names = sorted(entities)

    created = 0
    if auto_create:
        for entity_name in ordered_names:
            if await _auto_create_entity_if_new(kg, entity_name, entity_type):
                created += 1

    output = [
        f"📊 {entity_type.capitalize()}:",
        *(f"{entity_name}" for entity_name in ordered_names),
        "",
    ]
    return output, len(ordered_names), created

564 

565 

async def _extract_entities_from_context_impl(
    context: str,
    auto_create: bool = False,
) -> str:
    """Extract entities from ``context`` by pattern matching and report them.

    Args:
        context: Text to scan for entity mentions.
        auto_create: When true, entities not yet in the graph are created.

    Returns:
        The message produced by ``_execute_kg_operation``: a per-type listing
        with totals, or a "no entities detected" message.
    """

    async def extract_and_report(kg: Any) -> str:
        found = _extract_patterns_from_context(context)
        if not found:
            return "🔍 No entities detected in context"

        report = ["🔍 Extracted Entities from Context:", ""]
        extracted_total = 0
        created_total = 0

        for entity_type, names in found.items():
            section, listed, newly_created = await _process_entity_type(
                kg, entity_type, names, auto_create
            )
            report += section
            extracted_total += listed
            created_total += newly_created

        report.append(f"📊 Total Extracted: {extracted_total}")
        if auto_create:
            report.append(f"✅ Auto-created: {created_total} new entities")

        _get_logger().info(
            "Entities extracted from context",
            total_extracted=extracted_total,
            auto_created=created_total if auto_create else 0,
        )
        return "\n".join(report)

    return await _execute_kg_operation(
        "Extract entities from context", extract_and_report
    )

601 

602 

603# ============================================================================ 

604# Batch Operations 

605# ============================================================================ 

606 

607 

async def _create_single_entity(
    kg: Any, entity_data: dict[str, Any]
) -> tuple[str | None, tuple[str, str] | None]:
    """Create one entity of a batch without letting a failure escape.

    Args:
        kg: Knowledge graph object exposing an awaitable ``create_entity``.
        entity_data: Entity description with required ``name`` and
            ``entity_type`` keys and optional ``observations`` and
            ``properties``.

    Returns:
        ``(created_name, None)`` on success, or ``(None, (name, error))`` on
        failure. ``name`` is ``"<unknown>"`` when the input carries no
        usable name.
    """
    try:
        entity = await kg.create_entity(
            name=entity_data["name"],
            entity_type=entity_data["entity_type"],
            observations=entity_data.get("observations", []),
            properties=entity_data.get("properties", {}),
        )
        return entity["name"], None
    except Exception as e:
        # The failure may itself be a missing "name" key (or a non-dict item),
        # so the label must be looked up defensively. Indexing again here
        # would raise inside the handler and abort the whole batch.
        if isinstance(entity_data, dict):
            label = entity_data.get("name", "<unknown>")
        else:
            label = "<unknown>"
        return None, (label, str(e))

622 

623 

async def _batch_create_entities_operation(
    kg: Any, entities: list[dict[str, Any]]
) -> str:
    """Create several entities one after another and summarize the outcome.

    Args:
        kg: Knowledge graph object handed to ``_create_single_entity``.
        entities: Entity descriptions, one dict per entity.

    Returns:
        A newline-joined report with the success count and up to 10 created
        names, followed by the failure count and up to 5 failures when any
        entity could not be created.
    """
    succeeded = []
    errors = []

    for entity_data in entities:
        created_name, failure = await _create_single_entity(kg, entity_data)
        if created_name:
            succeeded.append(created_name)
            continue
        if failure:
            errors.append(failure)

    report = [
        "📦 Batch Entity Creation Results:",
        "",
        f"✅ Successfully Created: {len(succeeded)}",
    ]

    if succeeded:
        # Only the first 10 names are listed.
        report.extend(f"{name}" for name in succeeded[:10])
        if len(succeeded) > 10:
            report.append(f" ... and {len(succeeded) - 10} more")
        report.append("")

    if errors:
        report.append(f"❌ Failed: {len(errors)}")
        # Only the first 5 failures are listed.
        report.extend(f"{name}: {error}" for name, error in errors[:5])
        if len(errors) > 5:
            report.append(f" ... and {len(errors) - 5} more")

    _get_logger().info(
        "Batch entities created",
        total=len(entities),
        created=len(succeeded),
        failed=len(errors),
    )
    return "\n".join(report)

665 

666 

async def _batch_create_entities_impl(entities: list[dict[str, Any]]) -> str:
    """Create several entities, with shared knowledge graph error handling.

    Returns:
        The message produced by ``_execute_kg_operation``.
    """
    return await _execute_kg_operation(
        "Batch create entities",
        # The lambda returns the coroutine; the executor awaits it.
        lambda kg: _batch_create_entities_operation(kg, entities),
    )

677 

678 

679# ============================================================================ 

680# MCP Tool Registration 

681# ============================================================================ 

682 

683 

def register_knowledge_graph_tools(mcp_server: Any) -> None:
    """Register all knowledge graph MCP tools with the server."""
    # Each tool below is a thin wrapper that forwards its arguments unchanged
    # to the matching module-level ``_*_impl`` coroutine and returns its text.
    # The wrappers are handed to ``mcp_server.tool()``, so their names,
    # parameters and docstrings are left exactly as they are.
    # NOTE(review): presumably the server exposes those names/docstrings to
    # clients as the tool schema and description — confirm before editing.

    # --- Entity creation and observations ---

    @mcp_server.tool()  # type: ignore[misc]
    async def create_entity(
        name: str,
        entity_type: str,
        observations: list[str] | None = None,
        properties: dict[str, Any] | None = None,
    ) -> str:
        """Create an entity (node) in the knowledge graph."""
        return await _create_entity_impl(name, entity_type, observations, properties)

    @mcp_server.tool()  # type: ignore[misc]
    async def add_observation(entity_name: str, observation: str) -> str:
        """Add an observation (fact) to an existing entity."""
        return await _add_observation_impl(entity_name, observation)

    # --- Relationships ---

    @mcp_server.tool()  # type: ignore[misc]
    async def create_relation(
        from_entity: str,
        to_entity: str,
        relation_type: str,
        properties: dict[str, Any] | None = None,
    ) -> str:
        """Create a relationship between two entities in the knowledge graph."""
        return await _create_relation_impl(
            from_entity, to_entity, relation_type, properties
        )

    # --- Queries ---

    @mcp_server.tool()  # type: ignore[misc]
    async def search_entities(
        query: str,
        entity_type: str | None = None,
        limit: int = 10,
    ) -> str:
        """Search for entities by name or observations."""
        return await _search_entities_impl(query, entity_type, limit)

    @mcp_server.tool()  # type: ignore[misc]
    async def get_entity_relationships(
        entity_name: str,
        relation_type: str | None = None,
        direction: str = "both",
    ) -> str:
        """Get all relationships for a specific entity."""
        return await _get_entity_relationships_impl(
            entity_name, relation_type, direction
        )

    @mcp_server.tool()  # type: ignore[misc]
    async def find_path(
        from_entity: str,
        to_entity: str,
        max_depth: int = 5,
    ) -> str:
        """Find paths between two entities using DuckPGQ's SQL/PGQ graph queries."""
        return await _find_path_impl(from_entity, to_entity, max_depth)

    @mcp_server.tool()  # type: ignore[misc]
    async def get_knowledge_graph_stats() -> str:
        """Get statistics about the knowledge graph."""
        return await _get_knowledge_graph_stats_impl()

    # --- Extraction and bulk creation ---

    @mcp_server.tool()  # type: ignore[misc]
    async def extract_entities_from_context(
        context: str,
        auto_create: bool = False,
    ) -> str:
        """Extract entities from conversation context using pattern matching."""
        return await _extract_entities_from_context_impl(context, auto_create)

    @mcp_server.tool()  # type: ignore[misc]
    async def batch_create_entities(entities: list[dict[str, Any]]) -> str:
        """Bulk create multiple entities in one operation."""
        return await _batch_create_entities_impl(entities)