Coverage for session_buddy / tools / entity_extraction_tools.py: 23.08%
35 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
1#!/usr/bin/env python3
2"""MCP tools for multi-provider entity extraction and persistence."""
4from __future__ import annotations
6import typing as t
7from typing import TYPE_CHECKING
9from session_buddy.config.feature_flags import get_feature_flags
10from session_buddy.memory.entity_extractor import EntityExtractionEngine
11from session_buddy.memory.persistence import insert_processed_memory
13if TYPE_CHECKING:
14 from fastmcp import FastMCP
# Weights for blending the LLM-assigned importance with the activity signal.
# They sum to 1.0, so in-range inputs produce an in-range blend.
_LLM_IMPORTANCE_WEIGHT = 0.7
_ACTIVITY_IMPORTANCE_WEIGHT = 0.3


async def _compute_embedding(content: str) -> t.Any:
    """Best-effort embedding of ``content`` via the Oneiric reflection adapter.

    Returns ``None`` when the adapter cannot be imported or embedding fails for
    any reason, so the caller can still persist the memory without a vector.
    """
    try:
        from session_buddy.adapters.reflection_adapter_oneiric import (
            ReflectionDatabaseAdapterOneiric,
        )

        async with ReflectionDatabaseAdapterOneiric() as db:
            # NOTE(review): this calls a private adapter method; switch to a
            # public embedding API if the adapter offers one.
            return await db._generate_embedding(content)
    except Exception:
        # Optional dependency or model not available. Stay best-effort, but
        # leave a trace so a missing embedding is diagnosable.
        import logging

        logging.getLogger(__name__).debug(
            "Embedding unavailable; persisting memory without it", exc_info=True
        )
        return None


def _blend_importance(importance: float, activity_score: float) -> float:
    """Blend an LLM importance score with an activity score, clamped to [0, 1].

    ``activity_score`` is coerced with ``float()`` and clamped to [0, 1] first.
    A NaN activity score carries no signal, so ``importance`` is returned
    unchanged. Without this check, ``min(1.0, nan)`` evaluates to 1.0 and NaN
    would count as maximal activity.

    Raises:
        TypeError / ValueError: if ``activity_score`` cannot be converted to
        float, or ``importance`` is not numeric.
    """
    import math

    act = float(activity_score)
    if math.isnan(act):
        return importance
    act = max(0.0, min(1.0, act))
    blended = _LLM_IMPORTANCE_WEIGHT * importance + _ACTIVITY_IMPORTANCE_WEIGHT * act
    return max(0.0, min(1.0, blended))


async def extract_and_store_memory(
    user_input: str,
    ai_output: str,
    project: str | None = None,
    namespace: str = "default",
    activity_score: float | None = None,
) -> dict[str, t.Any]:
    """Extract entities using cascade and persist to v2 tables (when enabled).

    This is a module-level function that can be imported and called directly
    by both the MCP tool and internal modules like app_monitor.

    Args:
        user_input: The user's side of the exchange.
        ai_output: The assistant's side of the exchange.
        project: Optional project to associate with the stored memory.
        namespace: Namespace to store the memory under.
        activity_score: Optional activity signal. When given, it is blended
            into the extracted importance score (see ``_blend_importance``).
            Values that cannot be blended are ignored.

    Returns:
        ``{"status": "skipped", "reason": "feature_disabled"}`` when either
        the ``enable_llm_entity_extraction`` or the ``use_schema_v2`` feature
        flag is off. Otherwise ``{"status": "ok", ...}`` with the provider
        used, the extraction time, and the IDs of the persisted memory,
        entities and relationships.

    Raises:
        Errors from entity extraction or from ``insert_processed_memory`` are
        not caught here. Embedding, importance blending and access logging
        are best-effort and never raise.
    """
    from contextlib import suppress

    flags = get_feature_flags()
    if not flags.enable_llm_entity_extraction or not flags.use_schema_v2:
        return {
            "status": "skipped",
            "reason": "feature_disabled",
        }

    engine = EntityExtractionEngine()
    result = await engine.extract_entities(user_input, ai_output)

    # Persist into v2 tables; the embedding is optional.
    content = f"User: {user_input}\nAssistant: {ai_output}"
    embedding = await _compute_embedding(content)

    # Activity-based importance scoring: blend LLM importance with activity.
    pm = result.processed_memory
    if activity_score is not None:
        # Best-effort: a malformed score leaves the LLM importance untouched.
        with suppress(Exception):
            pm.importance_score = _blend_importance(
                pm.importance_score, activity_score
            )

    persist = insert_processed_memory(
        pm,
        content=content,
        project=project,
        namespace=namespace,
        embedding=embedding,
    )

    # Log extraction provider usage in the access log (for metrics); this
    # must never fail the store itself.
    with suppress(Exception):
        from session_buddy.memory.persistence import log_memory_access

        log_memory_access(
            persist.memory_id, access_type=f"extract:{result.llm_provider}"
        )

    return {
        "status": "ok",
        "llm_provider": result.llm_provider,
        "extraction_time_ms": result.extraction_time_ms,
        "memory_id": persist.memory_id,
        "entity_ids": persist.entity_ids,
        "relationship_ids": persist.relationship_ids,
    }
def register_extraction_tools(mcp: FastMCP) -> None:
    """Register the entity-extraction MCP tool(s) on ``mcp``.

    The registered tool is a thin async adapter that forwards its arguments
    unchanged to the module-level ``extract_and_store_memory`` and returns
    that function's result.
    """

    @mcp.tool()  # type: ignore[no-untyped-call]
    async def extract_and_store_memory_tool(
        user_input: str,
        ai_output: str,
        project: str | None = None,
        namespace: str = "default",
        activity_score: float | None = None,
    ) -> dict[str, t.Any]:
        """Extract entities using cascade and persist to v2 tables (when enabled)."""
        # The docstring above is the tool description MCP clients see.
        outcome = await extract_and_store_memory(
            user_input=user_input,
            ai_output=ai_output,
            project=project,
            namespace=namespace,
            activity_score=activity_score,
        )
        return outcome