Coverage for session_buddy / tools / knowledge_graph_tools.py: 37.86%
268 statements
Generated by coverage.py v7.13.1 at 2026-01-04 00:43 -0800
1#!/usr/bin/env python3
2"""Knowledge Graph MCP tools for semantic memory management.
4This module provides MCP tools for interacting with the DuckPGQ-based knowledge graph,
5enabling semantic memory through entity-relationship modeling.
7Refactored to use utility modules for reduced code duplication.
8"""
10from __future__ import annotations
12import re
13from typing import TYPE_CHECKING, Any
15from session_buddy.utils.error_handlers import _get_logger
16from session_buddy.utils.messages import ToolMessages
18if TYPE_CHECKING:
19 from collections.abc import Awaitable, Callable
21 from session_buddy.adapters.knowledge_graph_adapter import (
22 KnowledgeGraphDatabaseAdapter as KnowledgeGraphDatabase,
23 )
26# ============================================================================
27# Service Resolution
28# ============================================================================
31def _check_knowledge_graph_available() -> bool:
32 """Check if knowledge graph dependencies are available."""
33 try:
34 import duckdb
36 return True
37 except ImportError:
38 return False
41async def _require_knowledge_graph() -> KnowledgeGraphDatabase:
42 """Get knowledge graph database instance or raise error."""
43 try:
44 from session_buddy.adapters.knowledge_graph_adapter import (
45 KnowledgeGraphDatabaseAdapter,
46 )
47 from session_buddy.di import configure
49 configure()
50 kg = KnowledgeGraphDatabaseAdapter()
51 await kg.initialize()
52 return kg
53 except Exception as e:
54 msg = f"Knowledge graph not available: {e}"
55 raise RuntimeError(msg) from e
async def _execute_kg_operation(
    operation_name: str, operation: Callable[[Any], Awaitable[str]]
) -> str:
    """Run ``operation`` against a freshly opened knowledge graph.

    Acquiring the graph and running the operation are guarded separately.
    Only a failure to *obtain* the graph is reported as a missing dependency.
    Errors raised by the operation itself, including any ``RuntimeError``,
    are logged and reported as an operation failure instead of being
    misattributed to missing dependencies.

    Args:
        operation_name: Human-readable name used in logs and error messages.
        operation: Coroutine function that receives the open graph and returns
            the formatted tool output.

    Returns:
        The operation's output, or a user-facing error message.
    """
    try:
        adapter = await _require_knowledge_graph()
    except RuntimeError as e:
        # _require_knowledge_graph wraps every setup failure in RuntimeError.
        return f"❌ {e!s}. Install dependencies: uv sync"

    try:
        # The adapter is used as an async context manager so it is closed
        # even when the operation fails.
        async with adapter as kg:
            return await operation(kg)
    except Exception as e:
        _get_logger().exception(f"Error in {operation_name}: {e}")
        return ToolMessages.operation_failed(operation_name, e)
72# ============================================================================
73# Entity Extraction Patterns
74# ============================================================================
# Regex patterns used by ``_extract_patterns_from_context`` to spot entity
# mentions in free text. Each key is used as the entity type for its matches.
# Each pattern has exactly one capturing group, so ``re.findall`` yields plain
# strings.
# NOTE(review): the extractor applies ``re.IGNORECASE``. As a result, the
# "project" pattern's leading ``[A-Z]`` does not require a capital letter, and
# any hyphenated word (e.g. "well-known") matches — confirm this is intended.
ENTITY_PATTERNS = {
    "project": r"\b([A-Z][a-z]+-[a-z]+(?:-[a-z]+)*)\b",  # kebab-case projects
    "library": r"\b(ACB|FastMCP|DuckDB|pytest|pydantic|uvicorn)\b",
    "technology": r"\b(Python|JavaScript|TypeScript|Docker|Kubernetes)\b",
    "concept": r"\b(dependency injection|semantic memory|property graph|vector search)\b",
}
85# ============================================================================
86# Entity Operations
87# ============================================================================
async def _create_entity_operation(
    kg: Any,
    name: str,
    entity_type: str,
    observations: list[str],
    properties: dict[str, Any],
) -> str:
    """Persist a new entity and describe the result for the tool response.

    Args:
        kg: Open knowledge graph adapter.
        name: Entity name.
        entity_type: Entity type label.
        observations: Initial observations (may be empty).
        properties: Entity properties (may be empty).

    Returns:
        A multi-line summary including the ``id`` of the mapping returned by
        ``kg.create_entity``.
    """
    entity = await kg.create_entity(
        name=name,
        entity_type=entity_type,
        observations=observations,
        properties=properties,
    )

    summary = [
        f"✅ Entity '{name}' created successfully!",
        f"📊 Type: {entity_type}",
        f"🆔 ID: {entity['id']}",
    ]
    if observations:
        summary.append(f"📝 Observations: {len(observations)}")
    if properties:
        property_names = ", ".join(properties.keys())
        summary.append(f"⚙️ Properties: {property_names}")

    _get_logger().info(
        "Entity created",
        entity_name=name,
        entity_type=entity_type,
        observations_count=len(observations),
    )
    return "\n".join(summary)
async def _create_entity_impl(
    name: str,
    entity_type: str,
    observations: list[str] | None = None,
    properties: dict[str, Any] | None = None,
) -> str:
    """Tool entry point: create an entity in the knowledge graph.

    A falsy ``observations`` or ``properties`` (including ``None``) is
    replaced by an empty list or dict before delegating to
    ``_create_entity_operation``.
    """

    async def _run(kg: Any) -> str:
        return await _create_entity_operation(
            kg,
            name,
            entity_type,
            observations or [],
            properties or {},
        )

    return await _execute_kg_operation("Create entity", _run)
144async def _add_observation_operation(
145 kg: Any, entity_name: str, observation: str
146) -> str:
147 """Add an observation (fact) to an existing entity."""
148 success = await kg.add_observation(entity_name, observation)
150 if not success:
151 return f"❌ Entity '{entity_name}' not found"
153 _get_logger().info(
154 "Observation added",
155 entity_name=entity_name,
156 observation=observation[:100],
157 )
158 return "\n".join(
159 [
160 f"✅ Observation added to '{entity_name}'",
161 f"📝 Observation: {observation}",
162 ]
163 )
async def _add_observation_impl(entity_name: str, observation: str) -> str:
    """Tool entry point: add ``observation`` to the entity ``entity_name``."""

    async def _run(kg: Any) -> str:
        return await _add_observation_operation(kg, entity_name, observation)

    return await _execute_kg_operation("Add observation", _run)
178# ============================================================================
179# Relationship Operations
180# ============================================================================
183async def _create_relation_operation(
184 kg: Any,
185 from_entity: str,
186 to_entity: str,
187 relation_type: str,
188 properties: dict[str, Any],
189) -> str:
190 """Create a relationship between two entities."""
191 relation = await kg.create_relation(
192 from_entity=from_entity,
193 to_entity=to_entity,
194 relation_type=relation_type,
195 properties=properties,
196 )
198 if not relation:
199 return f"❌ One or both entities not found: {from_entity}, {to_entity}"
201 lines = [
202 f"✅ Relationship created: {from_entity} --[{relation_type}]--> {to_entity}",
203 f"🆔 Relation ID: {relation['id']}",
204 ]
206 if properties:
207 lines.append(f"⚙️ Properties: {', '.join(properties.keys())}")
209 _get_logger().info(
210 "Relation created",
211 from_entity=from_entity,
212 to_entity=to_entity,
213 relation_type=relation_type,
214 )
215 return "\n".join(lines)
async def _create_relation_impl(
    from_entity: str,
    to_entity: str,
    relation_type: str,
    properties: dict[str, Any] | None = None,
) -> str:
    """Tool entry point: create a relationship between two entities.

    A falsy ``properties`` (including ``None``) is replaced by an empty dict.
    """

    async def _run(kg: Any) -> str:
        relation_props = properties or {}
        return await _create_relation_operation(
            kg, from_entity, to_entity, relation_type, relation_props
        )

    return await _execute_kg_operation("Create relation", _run)
237# ============================================================================
238# Search Operations
239# ============================================================================
242def _format_entity_result(entity: dict[str, Any]) -> list[str]:
243 """Format a single entity search result."""
244 lines = [f"📌 {entity['name']} ({entity['entity_type']})"]
246 observations = entity.get("observations")
247 if observations:
248 lines.append(f" 📝 Observations: {len(observations)}")
249 if observations: 249 ↛ 253line 249 didn't jump to line 253 because the condition on line 249 was always true
250 preview = observations[0]
251 lines.append(f" └─ {preview[:80]}{'...' if len(preview) > 80 else ''}")
253 lines.append("")
254 return lines
257def _format_batch_results(
258 created: list[str],
259 failed: list[tuple[str, str]],
260) -> list[str]:
261 """Format batch entity creation results."""
262 lines = [
263 "📦 Batch Entity Creation Results",
264 "",
265 f"Successfully Created: {len(created)}",
266 ]
268 if created:
269 max_show = 10
270 for name in created[:max_show]:
271 lines.append(f" • {name}")
272 remaining = len(created) - max_show
273 if remaining > 0:
274 lines.append(f" • and {remaining} more")
276 if failed:
277 lines.append("")
278 lines.append(f"Failed: {len(failed)}")
279 max_failed = 5
280 for name, error in failed[:max_failed]:
281 lines.append(f" • {name}: {error}")
282 remaining_failed = len(failed) - max_failed
283 if remaining_failed > 0:
284 lines.append(f" • and {remaining_failed} more")
286 return lines
289async def _search_entities_operation(
290 kg: Any, query: str, entity_type: str | None, limit: int
291) -> str:
292 """Search for entities by name or observations."""
293 results = await kg.search_entities(
294 query=query,
295 entity_type=entity_type,
296 limit=limit,
297 )
299 if not results:
300 return f"🔍 No entities found matching '{query}'"
302 lines = [f"🔍 Found {len(results)} entities matching '{query}':", ""]
304 for entity in results:
305 lines.extend(_format_entity_result(entity))
307 _get_logger().info(
308 "Entities searched",
309 query=query,
310 entity_type=entity_type,
311 results_count=len(results),
312 )
313 return "\n".join(lines)
async def _search_entities_impl(
    query: str,
    entity_type: str | None = None,
    limit: int = 10,
) -> str:
    """Tool entry point: search entities, optionally filtered by type.

    Args:
        query: Search text forwarded to ``kg.search_entities``.
        entity_type: Optional entity type filter.
        limit: Result limit forwarded to ``kg.search_entities`` (default 10).
    """

    async def _run(kg: Any) -> str:
        return await _search_entities_operation(kg, query, entity_type, limit)

    return await _execute_kg_operation("Search entities", _run)
332def _format_relationship(rel: dict[str, Any], direction: str, entity_name: str) -> str:
333 """Format a single relationship based on direction."""
334 if direction == "outgoing" or (
335 direction == "both" and rel["from_entity"] == entity_name
336 ):
337 return (
338 f" {rel['from_entity']} --[{rel['relation_type']}]--> {rel['to_entity']}"
339 )
340 return f" {rel['from_entity']} <--[{rel['relation_type']}]-- {rel['to_entity']}"
343async def _get_entity_relationships_operation(
344 kg: Any, entity_name: str, relation_type: str | None, direction: str
345) -> str:
346 """Get all relationships for an entity."""
347 relationships = await kg.get_relationships(
348 entity_name=entity_name,
349 relation_type=relation_type,
350 direction=direction,
351 )
353 if not relationships:
354 return f"🔍 No relationships found for '{entity_name}'"
356 lines = [f"🔗 Found {len(relationships)} relationships for '{entity_name}':", ""]
358 for rel in relationships:
359 lines.append(_format_relationship(rel, direction, entity_name))
361 _get_logger().info(
362 "Relationships retrieved",
363 entity_name=entity_name,
364 relation_type=relation_type,
365 direction=direction,
366 count=len(relationships),
367 )
368 return "\n".join(lines)
async def _get_entity_relationships_impl(
    entity_name: str,
    relation_type: str | None = None,
    direction: str = "both",
) -> str:
    """Tool entry point: list relationships for an entity.

    ``relation_type`` and ``direction`` are forwarded unchanged to
    ``_get_entity_relationships_operation``.
    """

    async def _run(kg: Any) -> str:
        return await _get_entity_relationships_operation(
            kg,
            entity_name,
            relation_type,
            direction,
        )

    return await _execute_kg_operation("Get entity relationships", _run)
389# ============================================================================
390# Path Finding
391# ============================================================================
394async def _find_path_operation(
395 kg: Any, from_entity: str, to_entity: str, max_depth: int
396) -> str:
397 """Find paths between two entities using SQL/PGQ."""
398 paths = await kg.find_path(
399 from_entity=from_entity,
400 to_entity=to_entity,
401 max_depth=max_depth,
402 )
404 if not paths:
405 return f"🔍 No path found between '{from_entity}' and '{to_entity}'"
407 lines = [
408 f"🛤️ Found {len(paths)} path(s) from '{from_entity}' to '{to_entity}':",
409 "",
410 ]
412 for i, path in enumerate(paths, 1):
413 lines.extend(
414 [
415 f"{i}. Path length: {path['path_length']} hop(s)",
416 f" {path['from_entity']} ➜ ... ➜ {path['to_entity']}",
417 "",
418 ]
419 )
421 _get_logger().info(
422 "Paths found",
423 from_entity=from_entity,
424 to_entity=to_entity,
425 paths_count=len(paths),
426 )
427 return "\n".join(lines)
async def _find_path_impl(
    from_entity: str,
    to_entity: str,
    max_depth: int = 5,
) -> str:
    """Tool entry point: find paths between two entities.

    ``max_depth`` (default 5) is forwarded unchanged to
    ``_find_path_operation``.
    """

    async def _run(kg: Any) -> str:
        return await _find_path_operation(kg, from_entity, to_entity, max_depth)

    return await _execute_kg_operation("Find path", _run)
446# ============================================================================
447# Statistics
448# ============================================================================
451def _format_entity_types(entity_types: dict[str, int]) -> list[str]:
452 """Format entity type counts for statistics output."""
453 if not entity_types:
454 return []
456 lines = ["📊 Entity Types:"]
457 lines.extend(f" • {etype}: {count}" for etype, count in entity_types.items())
458 lines.append("")
459 return lines
462def _format_relationship_types(relationship_types: dict[str, int]) -> list[str]:
463 """Format relationship type counts for statistics output."""
464 if not relationship_types:
465 return []
467 lines = ["🔗 Relationship Types:"]
468 lines.extend(
469 f" • {rtype}: {count}" for rtype, count in relationship_types.items()
470 )
471 lines.append("")
472 return lines
async def _get_knowledge_graph_stats_operation(kg: Any) -> str:
    """Summarise ``kg.get_stats()`` as a human-readable report.

    The stats mapping must provide ``total_entities``,
    ``total_relationships``, ``database_path`` and ``duckpgq_installed``.
    ``entity_types`` and ``relationship_types`` are optional.
    """
    stats = await kg.get_stats()

    report = [
        "📊 Knowledge Graph Statistics",
        "",
        f"📌 Total Entities: {stats['total_entities']}",
        f"🔗 Total Relationships: {stats['total_relationships']}",
        "",
        *_format_entity_types(stats.get("entity_types", {})),
        *_format_relationship_types(stats.get("relationship_types", {})),
        f"💾 Database: {stats['database_path']}",
        "🔧 DuckPGQ: "
        + ("✅ Installed" if stats["duckpgq_installed"] else "❌ Not installed"),
    ]

    _get_logger().info("Knowledge graph stats retrieved", **stats)
    return "\n".join(report)
async def _get_knowledge_graph_stats_impl() -> str:
    """Tool entry point: report knowledge graph statistics."""
    operation_label = "Get KG stats"
    return await _execute_kg_operation(
        operation_label,
        _get_knowledge_graph_stats_operation,
    )
513# ============================================================================
514# Entity Extraction
515# ============================================================================
518def _extract_patterns_from_context(context: str) -> dict[str, set[str]]:
519 """Extract entity patterns from context text."""
520 extracted: dict[str, set[str]] = {}
521 for entity_type, pattern in ENTITY_PATTERNS.items():
522 matches = re.findall(pattern, context, re.IGNORECASE)
523 if matches:
524 extracted[entity_type] = set(matches)
525 return extracted
528async def _auto_create_entity_if_new(
529 kg: Any, entity_name: str, entity_type: str
530) -> bool:
531 """Create entity if it doesn't exist. Returns True if created."""
532 existing = await kg.find_entity_by_name(entity_name)
533 if not existing:
534 await kg.create_entity(
535 name=entity_name,
536 entity_type=entity_type,
537 observations=["Extracted from conversation context"],
538 )
539 return True
540 return False
543async def _process_entity_type(
544 kg: Any,
545 entity_type: str,
546 entities: set[str],
547 auto_create: bool,
548) -> tuple[list[str], int, int]:
549 """Process entities of a specific type."""
550 lines = [f"📊 {entity_type.capitalize()}:"]
551 count = 0
552 created = 0
554 for entity_name in sorted(entities):
555 lines.append(f" • {entity_name}")
556 count += 1
557 if auto_create and await _auto_create_entity_if_new(
558 kg, entity_name, entity_type
559 ):
560 created += 1
562 lines.append("")
563 return lines, count, created
async def _extract_entities_from_context_impl(
    context: str,
    auto_create: bool = False,
) -> str:
    """Extract entities from conversation context using pattern matching.

    The knowledge graph is only opened when ``auto_create`` is true. A plain
    extraction is pure regex matching, so it no longer fails when the graph
    backend is unavailable and no longer opens the database needlessly.

    Args:
        context: Free text to scan for entity mentions.
        auto_create: When true, mentioned entities that are not already in
            the graph are created.

    Returns:
        A report listing extracted entities per type. With ``auto_create``,
        it also states how many new entities were created. On error, a
        failure message is returned instead.
    """
    operation_name = "Extract entities from context"

    async def operation(kg: Any) -> str:
        extracted = _extract_patterns_from_context(context)
        if not extracted:
            return "🔍 No entities detected in context"

        lines = ["🔍 Extracted Entities from Context:", ""]
        total_extracted = 0
        created_count = 0

        for entity_type, entities in extracted.items():
            # When auto_create is false, _process_entity_type never touches
            # kg, so kg may be None on that path.
            type_lines, count, created = await _process_entity_type(
                kg, entity_type, entities, auto_create
            )
            lines.extend(type_lines)
            total_extracted += count
            created_count += created

        lines.append(f"📊 Total Extracted: {total_extracted}")
        if auto_create:
            lines.append(f"✅ Auto-created: {created_count} new entities")

        _get_logger().info(
            "Entities extracted from context",
            total_extracted=total_extracted,
            auto_created=created_count if auto_create else 0,
        )
        return "\n".join(lines)

    if auto_create:
        return await _execute_kg_operation(operation_name, operation)

    # Read-only extraction: report errors the same way as the graph-backed
    # path, but without opening (and therefore requiring) the knowledge graph.
    try:
        return await operation(None)
    except Exception as e:
        _get_logger().exception(f"Error in {operation_name}: {e}")
        return ToolMessages.operation_failed(operation_name, e)
603# ============================================================================
604# Batch Operations
605# ============================================================================
608async def _create_single_entity(
609 kg: Any, entity_data: dict[str, Any]
610) -> tuple[str | None, tuple[str, str] | None]:
611 """Create a single entity. Returns (created_name, None) or (None, (name, error))."""
612 try:
613 entity = await kg.create_entity(
614 name=entity_data["name"],
615 entity_type=entity_data["entity_type"],
616 observations=entity_data.get("observations", []),
617 properties=entity_data.get("properties", {}),
618 )
619 return entity["name"], None
620 except Exception as e:
621 return None, (entity_data["name"], str(e))
async def _batch_create_entities_operation(
    kg: Any, entities: list[dict[str, Any]]
) -> str:
    """Create each item in ``entities`` and summarise successes and failures.

    Items are created one at a time via ``_create_single_entity``. Every
    result it returns is counted as either created or failed, based on the
    failure slot rather than the truthiness of the returned name.

    Args:
        kg: Open knowledge graph adapter.
        entities: Batch items, each a mapping accepted by
            ``_create_single_entity``.

    Returns:
        A report listing up to 10 created names and up to 5 failures, with an
        overflow count for each.
    """
    created: list[str] = []
    failed: list[tuple[str, str]] = []

    for entity_data in entities:
        created_name, failure = await _create_single_entity(kg, entity_data)
        # Branch on the failure slot: a successfully created entity with an
        # empty name is falsy but must still be counted as created.
        if failure is None:
            created.append(created_name)
        else:
            failed.append(failure)

    lines = [
        "📦 Batch Entity Creation Results:",
        "",
        f"✅ Successfully Created: {len(created)}",
    ]

    if created:
        lines.extend(f" • {name}" for name in created[:10])  # Show first 10
        if len(created) > 10:
            lines.append(f" ... and {len(created) - 10} more")
        lines.append("")

    if failed:
        lines.append(f"❌ Failed: {len(failed)}")
        # Show first 5 failures
        lines.extend(f" • {name}: {error}" for name, error in failed[:5])
        if len(failed) > 5:
            lines.append(f" ... and {len(failed) - 5} more")

    _get_logger().info(
        "Batch entities created",
        total=len(entities),
        created=len(created),
        failed=len(failed),
    )
    return "\n".join(lines)
async def _batch_create_entities_impl(entities: list[dict[str, Any]]) -> str:
    """Tool entry point: create several entities in a single graph session."""

    async def _run(kg: Any) -> str:
        return await _batch_create_entities_operation(kg, entities)

    return await _execute_kg_operation("Batch create entities", _run)
679# ============================================================================
680# MCP Tool Registration
681# ============================================================================
def register_knowledge_graph_tools(mcp_server: Any) -> None:
    """Register all knowledge graph MCP tools with the server."""
    # Each nested coroutine below is registered through ``mcp_server.tool()``
    # and only forwards its arguments to the matching module-level ``_*_impl``
    # helper.
    # NOTE(review): presumably the server derives each tool's name, parameter
    # schema and description from the function name, signature and docstring.
    # If so, renaming parameters or editing these docstrings changes the
    # public tool contract — confirm against the MCP server library.

    # ``observations`` / ``properties`` may be None; the impl substitutes
    # empty containers.
    @mcp_server.tool()  # type: ignore[misc]
    async def create_entity(
        name: str,
        entity_type: str,
        observations: list[str] | None = None,
        properties: dict[str, Any] | None = None,
    ) -> str:
        """Create an entity (node) in the knowledge graph."""
        return await _create_entity_impl(name, entity_type, observations, properties)

    @mcp_server.tool()  # type: ignore[misc]
    async def add_observation(entity_name: str, observation: str) -> str:
        """Add an observation (fact) to an existing entity."""
        return await _add_observation_impl(entity_name, observation)

    @mcp_server.tool()  # type: ignore[misc]
    async def create_relation(
        from_entity: str,
        to_entity: str,
        relation_type: str,
        properties: dict[str, Any] | None = None,
    ) -> str:
        """Create a relationship between two entities in the knowledge graph."""
        return await _create_relation_impl(
            from_entity, to_entity, relation_type, properties
        )

    @mcp_server.tool()  # type: ignore[misc]
    async def search_entities(
        query: str,
        entity_type: str | None = None,
        limit: int = 10,
    ) -> str:
        """Search for entities by name or observations."""
        return await _search_entities_impl(query, entity_type, limit)

    # ``direction`` is passed through unchanged to the backend query and also
    # selects the arrow style used when formatting each relationship.
    @mcp_server.tool()  # type: ignore[misc]
    async def get_entity_relationships(
        entity_name: str,
        relation_type: str | None = None,
        direction: str = "both",
    ) -> str:
        """Get all relationships for a specific entity."""
        return await _get_entity_relationships_impl(
            entity_name, relation_type, direction
        )

    @mcp_server.tool()  # type: ignore[misc]
    async def find_path(
        from_entity: str,
        to_entity: str,
        max_depth: int = 5,
    ) -> str:
        """Find paths between two entities using DuckPGQ's SQL/PGQ graph queries."""
        return await _find_path_impl(from_entity, to_entity, max_depth)

    @mcp_server.tool()  # type: ignore[misc]
    async def get_knowledge_graph_stats() -> str:
        """Get statistics about the knowledge graph."""
        return await _get_knowledge_graph_stats_impl()

    # With ``auto_create`` false this only runs regex matching; with it true,
    # entities not found by name in the graph are created.
    @mcp_server.tool()  # type: ignore[misc]
    async def extract_entities_from_context(
        context: str,
        auto_create: bool = False,
    ) -> str:
        """Extract entities from conversation context using pattern matching."""
        return await _extract_entities_from_context_impl(context, auto_create)

    # Each item is expected to carry ``name`` and ``entity_type`` keys;
    # ``observations`` and ``properties`` are optional.
    @mcp_server.tool()  # type: ignore[misc]
    async def batch_create_entities(entities: list[dict[str, Any]]) -> str:
        """Bulk create multiple entities in one operation."""
        return await _batch_create_entities_impl(entities)