Coverage for mcp_bridge/tools/semantic_search.py: 41%
1469 statements
coverage.py v7.10.1, created at 2026-01-10 00:20 -0500
1"""
2Semantic Code Search - Vector-based code understanding
4Uses ChromaDB for persistent vector storage with multiple embedding providers:
5- Ollama (local, free) - nomic-embed-text (768 dims)
6- Mxbai (local, free) - mxbai-embed-large (1024 dims, better for code)
7- Gemini (cloud, OAuth) - gemini-embedding-001 (768-3072 dims)
8- OpenAI (cloud, OAuth) - text-embedding-3-small (1536 dims)
9- HuggingFace (cloud, token) - sentence-transformers/all-mpnet-base-v2 (768 dims)
11Enables natural language queries like "find authentication logic" without
12requiring exact pattern matching.
14Architecture:
15- Per-project ChromaDB storage at ~/.stravinsky/vectordb/<project_hash>/
16- Lazy initialization on first query
17- Provider abstraction for embedding generation
18- Chunking strategy: function/class level with context
19"""

import atexit
import hashlib
import logging
import sys
import threading
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal

import httpx
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from mcp_bridge.auth.token_store import TokenStore
from mcp_bridge.tools.query_classifier import QueryCategory, classify_query

logger = logging.getLogger(__name__)


# Lazy imports for watchdog (avoid startup cost)
_watchdog = None
_watchdog_import_lock = threading.Lock()


def get_watchdog():
    """Lazy import of watchdog components for file watching."""
    global _watchdog
    if _watchdog is None:
        with _watchdog_import_lock:
            if _watchdog is None:
                from watchdog.events import FileSystemEventHandler
                from watchdog.observers import Observer

                _watchdog = {"Observer": Observer, "FileSystemEventHandler": FileSystemEventHandler}
    return _watchdog


# Embedding provider type
EmbeddingProvider = Literal["ollama", "mxbai", "gemini", "openai", "huggingface"]

# Lazy imports to avoid startup cost
_chromadb = None
_ollama = None
_httpx = None
_filelock = None
_import_lock = threading.Lock()


def get_filelock():
    global _filelock
    if _filelock is None:
        with _import_lock:
            if _filelock is None:
                import filelock

                _filelock = filelock
    return _filelock


def get_chromadb():
    global _chromadb
    if _chromadb is None:
        with _import_lock:
            if _chromadb is None:
                try:
                    import chromadb

                    _chromadb = chromadb
                except ImportError as e:
                    if sys.version_info >= (3, 14):
                        raise ImportError(
                            "ChromaDB is not available on Python 3.14+. "
                            "Semantic search is not supported on Python 3.14 yet. "
                            "Use Python 3.11-3.13 for semantic search features."
                        ) from e
                    raise
    return _chromadb


def get_ollama():
    global _ollama
    if _ollama is None:
        with _import_lock:
            if _ollama is None:
                import ollama

                _ollama = ollama
    return _ollama


def get_httpx():
    global _httpx
    if _httpx is None:
        with _import_lock:
            if _httpx is None:
                import httpx

                _httpx = httpx
    return _httpx


# ========================
# GITIGNORE MANAGER
# ========================

# Lazy import for pathspec
_pathspec = None
_pathspec_lock = threading.Lock()


def get_pathspec():
    """Lazy import of pathspec for gitignore pattern matching."""
    global _pathspec
    if _pathspec is None:
        with _pathspec_lock:
            if _pathspec is None:
                import pathspec

                _pathspec = pathspec
    return _pathspec


class GitIgnoreManager:
    """Manages .gitignore and .stravignore pattern matching.

    Loads and caches gitignore-style patterns from:
    - .gitignore (standard git ignore patterns)
    - .stravignore (Stravinsky-specific ignore patterns)

    Patterns are combined and cached per project for efficient matching.
    The manager automatically reloads patterns if the ignore files are modified.
    """

    # Cache of GitIgnoreManager instances per project path
    _instances: dict[str, "GitIgnoreManager"] = {}
    _instances_lock = threading.Lock()

    @classmethod
    def get_instance(cls, project_path: Path) -> "GitIgnoreManager":
        """Get or create a GitIgnoreManager for a project.

        Args:
            project_path: Root path of the project

        Returns:
            Cached GitIgnoreManager instance for the project
        """
        path_str = str(project_path.resolve())
        if path_str not in cls._instances:
            with cls._instances_lock:
                if path_str not in cls._instances:
                    cls._instances[path_str] = cls(project_path)
        return cls._instances[path_str]

    @classmethod
    def clear_cache(cls, project_path: Path | None = None) -> None:
        """Clear cached GitIgnoreManager instances.

        Args:
            project_path: Clear only this project's cache, or all if None
        """
        with cls._instances_lock:
            if project_path is None:
                cls._instances.clear()
            else:
                path_str = str(project_path.resolve())
                cls._instances.pop(path_str, None)

    def __init__(self, project_path: Path):
        """Initialize the GitIgnoreManager.

        Args:
            project_path: Root path of the project
        """
        self.project_path = project_path.resolve()
        self._spec = None
        self._gitignore_mtime: float | None = None
        self._stravignore_mtime: float | None = None
        self._lock = threading.Lock()

    def _get_file_mtime(self, file_path: Path) -> float | None:
        """Get modification time of a file, or None if it doesn't exist."""
        try:
            return file_path.stat().st_mtime
        except (OSError, FileNotFoundError):
            return None

    def _needs_reload(self) -> bool:
        """Check if ignore patterns need to be reloaded."""
        gitignore_path = self.project_path / ".gitignore"
        stravignore_path = self.project_path / ".stravignore"

        current_gitignore_mtime = self._get_file_mtime(gitignore_path)
        current_stravignore_mtime = self._get_file_mtime(stravignore_path)

        # Check if either file has been modified or if we haven't loaded yet
        if self._spec is None:
            return True

        if current_gitignore_mtime != self._gitignore_mtime:
            return True

        if current_stravignore_mtime != self._stravignore_mtime:
            return True

        return False

    def _load_patterns(self) -> None:
        """Load patterns from .gitignore and .stravignore files."""
        pathspec = get_pathspec()

        patterns = []
        gitignore_path = self.project_path / ".gitignore"
        stravignore_path = self.project_path / ".stravignore"

        # Load .gitignore patterns
        if gitignore_path.exists():
            try:
                with open(gitignore_path, encoding="utf-8") as f:
                    patterns.extend(f.read().splitlines())
                self._gitignore_mtime = self._get_file_mtime(gitignore_path)
                logger.debug(f"Loaded .gitignore from {gitignore_path}")
            except Exception as e:
                logger.warning(f"Failed to load .gitignore: {e}")
                self._gitignore_mtime = None
        else:
            self._gitignore_mtime = None

        # Load .stravignore patterns
        if stravignore_path.exists():
            try:
                with open(stravignore_path, encoding="utf-8") as f:
                    patterns.extend(f.read().splitlines())
                self._stravignore_mtime = self._get_file_mtime(stravignore_path)
                logger.debug(f"Loaded .stravignore from {stravignore_path}")
            except Exception as e:
                logger.warning(f"Failed to load .stravignore: {e}")
                self._stravignore_mtime = None
        else:
            self._stravignore_mtime = None

        # Filter out empty lines and comments
        patterns = [p for p in patterns if p.strip() and not p.strip().startswith("#")]

        # Create pathspec matcher
        self._spec = pathspec.PathSpec.from_lines("gitwildmatch", patterns)
        logger.debug(f"Loaded {len(patterns)} ignore patterns for {self.project_path}")

    @property
    def spec(self):
        """Get the PathSpec matcher, reloading if necessary."""
        with self._lock:
            if self._needs_reload():
                self._load_patterns()
            return self._spec

    def is_ignored(self, file_path: Path) -> bool:
        """Check if a file path should be ignored.

        Args:
            file_path: Absolute or relative path to check

        Returns:
            True if the file matches any ignore pattern, False otherwise
        """
        try:
            # Convert to relative path from project root
            if file_path.is_absolute():
                rel_path = file_path.resolve().relative_to(self.project_path)
            else:
                rel_path = file_path

            # pathspec expects forward slashes and string paths
            rel_path_str = str(rel_path).replace("\\", "/")

            # Check against patterns
            return self.spec.match_file(rel_path_str)
        except ValueError:
            # Path is outside project - not ignored by gitignore (but may be ignored for other reasons)
            return False
        except Exception as e:
            logger.warning(f"Error checking ignore status for {file_path}: {e}")
            return False

    def filter_paths(self, paths: list[Path]) -> list[Path]:
        """Filter a list of paths, removing ignored ones.

        Args:
            paths: List of paths to filter

        Returns:
            List of paths that are not ignored
        """
        return [p for p in paths if not self.is_ignored(p)]
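

# Usage sketch for GitIgnoreManager (hypothetical paths; a minimal example of
# the cached-instance API defined above):
#
#     manager = GitIgnoreManager.get_instance(Path("/path/to/project"))
#     manager.is_ignored(Path("node_modules/left-pad/index.js"))  # True if a pattern matches
#     sources = manager.filter_paths([Path("src/app.py"), Path(".venv/lib/foo.py")])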


# ========================
# EMBEDDING PROVIDERS
# ========================


class BaseEmbeddingProvider(ABC):
    """Abstract base class for embedding providers."""

    @abstractmethod
    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding vector for text."""
        pass

    @abstractmethod
    async def check_available(self) -> bool:
        """Check if the provider is available and ready."""
        pass

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Return the embedding dimension for this provider."""
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the provider name."""
        pass


class OllamaProvider(BaseEmbeddingProvider):
    """Ollama local embedding provider using nomic-embed-text."""

    MODEL = "nomic-embed-text"
    DIMENSION = 768

    def __init__(self):
        self._available: bool | None = None

    @property
    def dimension(self) -> int:
        return self.DIMENSION

    @property
    def name(self) -> str:
        return "ollama"

    async def check_available(self) -> bool:
        if self._available is not None:
            return self._available

        try:
            ollama = get_ollama()
            models = ollama.list()
            model_names = [m.model for m in models.models] if hasattr(models, "models") else []

            if not any(name and self.MODEL in name for name in model_names):
                print(
                    f"⚠️ Embedding model '{self.MODEL}' not found. Run: ollama pull {self.MODEL}",
                    file=sys.stderr,
                )
                self._available = False
                return False

            self._available = True
            return True
        except Exception as e:
            print(f"⚠️ Ollama not available: {e}. Start with: ollama serve", file=sys.stderr)
            self._available = False
            return False

    async def get_embedding(self, text: str) -> list[float]:
        ollama = get_ollama()
        # nomic-embed-text has an 8192-token context. Code can be 1-2 chars/token.
        # Truncate to 2000 chars (~1000-2000 tokens) for a large safety margin.
        truncated = text[:2000] if len(text) > 2000 else text
        response = ollama.embeddings(model=self.MODEL, prompt=truncated)
        return response["embedding"]


class GeminiProvider(BaseEmbeddingProvider):
    """Gemini embedding provider using OAuth authentication."""

    MODEL = "gemini-embedding-001"
    DIMENSION = 768  # Using 768 for efficiency; the model supports up to 3072

    def __init__(self):
        self._available: bool | None = None
        self._token_store = None

    def _get_token_store(self):
        if self._token_store is None:
            from ..auth.token_store import TokenStore

            self._token_store = TokenStore()
        return self._token_store

    @property
    def dimension(self) -> int:
        return self.DIMENSION

    @property
    def name(self) -> str:
        return "gemini"

    async def check_available(self) -> bool:
        if self._available is not None:
            return self._available

        try:
            token_store = self._get_token_store()
            access_token = token_store.get_access_token("gemini")

            if not access_token:
                print(
                    "⚠️ Gemini not authenticated. Run: stravinsky-auth login gemini",
                    file=sys.stderr,
                )
                self._available = False
                return False

            self._available = True
            return True
        except Exception as e:
            print(f"⚠️ Gemini not available: {e}", file=sys.stderr)
            self._available = False
            return False

    async def get_embedding(self, text: str) -> list[float]:
        import os
        import uuid

        from ..auth.oauth import (
            ANTIGRAVITY_DEFAULT_PROJECT_ID,
            ANTIGRAVITY_ENDPOINTS,
            ANTIGRAVITY_HEADERS,
        )

        token_store = self._get_token_store()
        access_token = token_store.get_access_token("gemini")

        if not access_token:
            raise ValueError("Not authenticated with Gemini. Run: stravinsky-auth login gemini")

        httpx = get_httpx()

        # Use Antigravity endpoint for embeddings (same auth as invoke_gemini)
        project_id = os.getenv("STRAVINSKY_ANTIGRAVITY_PROJECT_ID", ANTIGRAVITY_DEFAULT_PROJECT_ID)

        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            **ANTIGRAVITY_HEADERS,
        }

        # Wrap request for Antigravity API
        inner_payload = {
            "model": f"models/{self.MODEL}",
            "content": {"parts": [{"text": text}]},
            "outputDimensionality": self.DIMENSION,
        }

        wrapped_payload = {
            "project": project_id,
            "model": self.MODEL,
            "userAgent": "antigravity",
            "requestId": f"embed-{uuid.uuid4()}",
            "request": inner_payload,
        }

        # Try endpoints in order
        last_error = None
        async with httpx.AsyncClient(timeout=60.0) as client:
            for endpoint in ANTIGRAVITY_ENDPOINTS:
                api_url = f"{endpoint}/v1internal:embedContent"

                try:
                    response = await client.post(
                        api_url,
                        headers=headers,
                        json=wrapped_payload,
                    )

                    if response.status_code in (401, 403):
                        last_error = Exception(f"{response.status_code} from {endpoint}")
                        continue

                    response.raise_for_status()
                    data = response.json()

                    # Extract embedding from response
                    inner_response = data.get("response", data)
                    embedding = inner_response.get("embedding", {})
                    values = embedding.get("values", [])

                    if values:
                        return values

                    raise ValueError(f"No embedding values in response: {data}")

                except Exception as e:
                    last_error = e
                    continue

        raise ValueError(f"All Antigravity endpoints failed for embeddings: {last_error}")


class OpenAIProvider(BaseEmbeddingProvider):
    """OpenAI embedding provider using OAuth authentication."""

    MODEL = "text-embedding-3-small"
    DIMENSION = 1536

    def __init__(self):
        self._available: bool | None = None
        self._token_store = None

    def _get_token_store(self):
        if self._token_store is None:
            from ..auth.token_store import TokenStore

            self._token_store = TokenStore()
        return self._token_store

    @property
    def dimension(self) -> int:
        return self.DIMENSION

    @property
    def name(self) -> str:
        return "openai"

    async def check_available(self) -> bool:
        if self._available is not None:
            return self._available

        try:
            token_store = self._get_token_store()
            access_token = token_store.get_access_token("openai")

            if not access_token:
                print(
                    "⚠️ OpenAI not authenticated. Run: stravinsky-auth login openai",
                    file=sys.stderr,
                )
                self._available = False
                return False

            self._available = True
            return True
        except Exception as e:
            print(f"⚠️ OpenAI not available: {e}", file=sys.stderr)
            self._available = False
            return False

    async def get_embedding(self, text: str) -> list[float]:
        token_store = self._get_token_store()
        access_token = token_store.get_access_token("openai")

        if not access_token:
            raise ValueError("Not authenticated with OpenAI. Run: stravinsky-auth login openai")

        httpx = get_httpx()

        # Use standard OpenAI API for embeddings
        api_url = "https://api.openai.com/v1/embeddings"

        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

        payload = {
            "model": self.MODEL,
            "input": text,
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(api_url, headers=headers, json=payload)

            if response.status_code == 401:
                raise ValueError("OpenAI authentication failed. Run: stravinsky-auth login openai")

            response.raise_for_status()
            data = response.json()

            # Extract embedding from response
            embeddings = data.get("data", [])
            if embeddings and "embedding" in embeddings[0]:
                return embeddings[0]["embedding"]

            raise ValueError(f"No embedding in response: {data}")


class MxbaiProvider(BaseEmbeddingProvider):
    """Ollama local embedding provider using mxbai-embed-large (better for code).

    mxbai-embed-large is a 1024-dimensional model optimized for code understanding.
    It generally outperforms nomic-embed-text on code-related retrieval tasks.
    """

    MODEL = "mxbai-embed-large"
    DIMENSION = 1024

    def __init__(self):
        self._available: bool | None = None

    @property
    def dimension(self) -> int:
        return self.DIMENSION

    @property
    def name(self) -> str:
        return "mxbai"

    async def check_available(self) -> bool:
        if self._available is not None:
            return self._available

        try:
            ollama = get_ollama()
            models = ollama.list()
            model_names = [m.model for m in models.models] if hasattr(models, "models") else []

            if not any(name and self.MODEL in name for name in model_names):
                print(
                    f"⚠️ Embedding model '{self.MODEL}' not found. Run: ollama pull {self.MODEL}",
                    file=sys.stderr,
                )
                self._available = False
                return False

            self._available = True
            return True
        except Exception as e:
            print(f"⚠️ Ollama not available: {e}. Start with: ollama serve", file=sys.stderr)
            self._available = False
            return False

    async def get_embedding(self, text: str) -> list[float]:
        ollama = get_ollama()
        # mxbai-embed-large has a 512-token context. At 1-2 chars/token, 2000
        # chars (~1000-2000 tokens) can still overflow it; Ollama truncates the
        # excess to the context window, so this cap mainly bounds request size.
        truncated = text[:2000] if len(text) > 2000 else text
        response = ollama.embeddings(model=self.MODEL, prompt=truncated)
        return response["embedding"]


class HuggingFaceProvider(BaseEmbeddingProvider):
    """Hugging Face Inference API embedding provider.

    Uses the Hugging Face Inference API for embeddings. Requires HF_TOKEN from:
    1. Environment variable: HF_TOKEN or HUGGING_FACE_HUB_TOKEN
    2. HF CLI config: ~/.cache/huggingface/token or ~/.huggingface/token

    Default model: sentence-transformers/all-mpnet-base-v2 (768 dims, high quality)
    """

    DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
    DEFAULT_DIMENSION = 768

    def __init__(self, model: str | None = None):
        self._available: bool | None = None
        self._model = model or self.DEFAULT_MODEL
        # Dimension varies by model, but we'll use the default for common models
        self._dimension = self.DEFAULT_DIMENSION
        self._token: str | None = None

    @property
    def dimension(self) -> int:
        return self._dimension

    @property
    def name(self) -> str:
        return "huggingface"

    def _get_hf_token(self) -> str | None:
        """Discover HF token from environment or CLI config."""
        import os

        # Check environment variables first
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
        if token:
            return token

        # Check HF CLI config locations
        hf_token_paths = [
            Path.home() / ".cache" / "huggingface" / "token",
            Path.home() / ".huggingface" / "token",
        ]

        for token_path in hf_token_paths:
            if token_path.exists():
                try:
                    return token_path.read_text().strip()
                except Exception:
                    continue

        return None

    async def check_available(self) -> bool:
        if self._available is not None:
            return self._available

        try:
            self._token = self._get_hf_token()
            if not self._token:
                print(
                    "⚠️ Hugging Face token not found. Run: huggingface-cli login or set HF_TOKEN env var",
                    file=sys.stderr,
                )
                self._available = False
                return False

            self._available = True
            return True
        except Exception as e:
            print(f"⚠️ Hugging Face not available: {e}", file=sys.stderr)
            self._available = False
            return False

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type(httpx.HTTPStatusError),
    )
    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding from HF Inference API with retry logic."""
        if not self._token:
            self._token = self._get_hf_token()
        if not self._token:
            raise ValueError(
                "Hugging Face token not found. Run: huggingface-cli login or set HF_TOKEN"
            )

        httpx_client = get_httpx()

        # HF Serverless Inference API endpoint
        # Note: Free tier may have limited availability for some models
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{self._model}"

        headers = {
            "Authorization": f"Bearer {self._token}",
        }

        # Truncate text to reasonable length (most models have a 512-token limit)
        # ~2000 chars ≈ 500 tokens for safety
        truncated = text[:2000] if len(text) > 2000 else text

        # HF Inference API accepts raw JSON with an "inputs" field
        payload = {"inputs": [truncated], "options": {"wait_for_model": True}}

        async with httpx_client.AsyncClient(timeout=60.0) as client:
            response = await client.post(api_url, headers=headers, json=payload)

            # Handle specific error codes
            if response.status_code == 401:
                raise ValueError(
                    "Hugging Face authentication failed. Run: huggingface-cli login or set HF_TOKEN"
                )
            elif response.status_code == 410:
                # Model removed from free tier
                raise ValueError(
                    f"Model {self._model} is no longer available on HF free Inference API (410 Gone). "
                    "Try a different model or use Ollama for local embeddings instead."
                )
            elif response.status_code == 503:
                # Model loading - retry will handle this
                logger.info(f"Model {self._model} is loading, retrying...")
                response.raise_for_status()
            elif response.status_code == 429:
                # Rate limit - retry will handle with exponential backoff
                logger.warning("HF API rate limit hit, retrying with backoff...")
                response.raise_for_status()

            response.raise_for_status()

            # Response is a single embedding vector (list of floats)
            embedding = response.json()

            # Handle different response formats
            if isinstance(embedding, list):
                # Direct embedding or batch with a single item
                if isinstance(embedding[0], (int, float)):
                    return embedding
                elif isinstance(embedding[0], list):
                    # Batch response with a single embedding
                    return embedding[0]

            raise ValueError(f"Unexpected response format from HF API: {type(embedding)}")

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Batch embedding support for the HF API.

        The HF API supports batch requests, so we can send multiple texts at once.
        """
        if not texts:
            return []

        if not self._token:
            self._token = self._get_hf_token()
        if not self._token:
            raise ValueError(
                "Hugging Face token not found. Run: huggingface-cli login or set HF_TOKEN"
            )

        httpx_client = get_httpx()

        # HF Serverless Inference API endpoint
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{self._model}"

        headers = {
            "Authorization": f"Bearer {self._token}",
        }

        # Truncate all texts
        truncated_texts = [text[:2000] if len(text) > 2000 else text for text in texts]

        payload = {"inputs": truncated_texts, "options": {"wait_for_model": True}}

        async with httpx_client.AsyncClient(timeout=120.0) as client:
            response = await client.post(api_url, headers=headers, json=payload)

            if response.status_code == 401:
                raise ValueError(
                    "Hugging Face authentication failed. Run: huggingface-cli login or set HF_TOKEN"
                )

            response.raise_for_status()

            embeddings = response.json()

            # Response should be a list of embeddings
            if isinstance(embeddings, list) and all(isinstance(e, list) for e in embeddings):
                return embeddings

            raise ValueError(f"Unexpected batch response format from HF API: {type(embeddings)}")


# Embedding provider instance cache
_embedding_provider_cache: dict[str, BaseEmbeddingProvider] = {}
_embedding_provider_lock = threading.Lock()


def get_embedding_provider(provider: EmbeddingProvider) -> BaseEmbeddingProvider:
    """Factory function to get an embedding provider instance with caching."""
    if provider not in _embedding_provider_cache:
        with _embedding_provider_lock:
            # Double-check pattern to avoid race condition
            if provider not in _embedding_provider_cache:
                providers = {
                    "ollama": OllamaProvider,
                    "mxbai": MxbaiProvider,
                    "gemini": GeminiProvider,
                    "openai": OpenAIProvider,
                    "huggingface": HuggingFaceProvider,
                }

                if provider not in providers:
                    raise ValueError(
                        f"Unknown provider: {provider}. Available: {list(providers.keys())}"
                    )

                _embedding_provider_cache[provider] = providers[provider]()

    return _embedding_provider_cache[provider]
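

# Usage sketch for the provider factory (a minimal example; assumes Ollama is
# running locally with nomic-embed-text pulled, and an async context):
#
#     provider = get_embedding_provider("ollama")
#     if await provider.check_available():
#         vector = await provider.get_embedding("def login(user): ...")
#         assert len(vector) == provider.dimension  # 768 for nomic-embed-text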


class CodebaseVectorStore:
    """
    Persistent vector store for a single codebase.

    Storage: ~/.stravinsky/vectordb/<project_hash>_<provider>/
    Embedding: Configurable via provider (ollama, mxbai, gemini, openai, huggingface)
    """

    CHUNK_SIZE = 50  # lines per chunk
    CHUNK_OVERLAP = 10  # lines of overlap between chunks

    # File patterns to index
    CODE_EXTENSIONS = {
        ".py",
        ".js",
        ".ts",
        ".tsx",
        ".jsx",
        ".go",
        ".rs",
        ".rb",
        ".java",
        ".c",
        ".cpp",
        ".h",
        ".hpp",
        ".cs",
        ".swift",
        ".kt",
        ".scala",
        ".vue",
        ".svelte",
        ".md",
        ".txt",
        ".yaml",
        ".yml",
        ".json",
        ".toml",
    }

    # Directories to skip (non-code related).
    # Entries are matched literally against path components (see
    # _get_files_to_index), so glob-style entries like "*.egg-info" only match
    # a directory literally named "*.egg-info".
    SKIP_DIRS = {
        # Python
        "__pycache__",
        ".venv",
        "venv",
        "env",
        ".env",
        "virtualenv",
        ".virtualenv",
        ".tox",
        ".nox",
        ".pytest_cache",
        ".mypy_cache",
        ".ruff_cache",
        ".pytype",
        ".pyre",
        "*.egg-info",
        ".eggs",
        "pip-wheel-metadata",
        # Node.js
        "node_modules",
        ".npm",
        ".yarn",
        ".pnpm-store",
        "bower_components",
        # Build outputs
        "dist",
        "build",
        "out",
        "_build",
        ".next",
        ".nuxt",
        ".output",
        ".cache",
        ".parcel-cache",
        ".turbo",
        # Version control
        ".git",
        ".svn",
        ".hg",
        # IDE/Editor
        ".idea",
        ".vscode",
        ".vs",
        # Test/coverage
        "coverage",
        "htmlcov",
        ".coverage",
        ".nyc_output",
        # Rust/Go/Java
        "target",
        "vendor",
        "Godeps",
        # Misc
        ".stravinsky",
        "scratches",
        "consoles",
        "logs",
        "tmp",
        "temp",
    }

    @staticmethod
    def _normalize_project_path(path: str) -> Path:
        """
        Normalize project path to git root if available.

        This ensures one index per repo regardless of invocation directory.
        If not a git repo, returns the resolved absolute path.
        """
        import subprocess

        resolved = Path(path).resolve()

        # Try to find git root
        try:
            result = subprocess.run(
                ["git", "-C", str(resolved), "rev-parse", "--show-toplevel"],
                capture_output=True,
                text=True,
                timeout=2,
                check=False,
            )
            if result.returncode == 0:
                git_root = Path(result.stdout.strip())
                logger.debug(f"Normalized {resolved} → {git_root} (git root)")
                return git_root
        except (subprocess.TimeoutExpired, FileNotFoundError):
            pass

        # Not a git repo or git not available, use resolved path
        return resolved

    def __init__(self, project_path: str, provider: EmbeddingProvider = "ollama"):
        self.project_path = self._normalize_project_path(project_path)
        self.project_hash = hashlib.md5(str(self.project_path).encode()).hexdigest()[:12]

        # Initialize embedding provider
        self.provider_name = provider
        self.provider = get_embedding_provider(provider)

        # Store in user's home directory, separated by provider to avoid dimension mismatch
        self.db_path = Path.home() / ".stravinsky" / "vectordb" / f"{self.project_hash}_{provider}"
        self.db_path.mkdir(parents=True, exist_ok=True)

        # File lock for single-process access to ChromaDB (prevents corruption)
        self._lock_path = self.db_path / ".chromadb.lock"
        self._file_lock = None

        self._client = None
        self._collection = None

        # File watcher attributes
        self._watcher: CodebaseFileWatcher | None = None
        self._watcher_lock = threading.Lock()

        # Cancellation flag for indexing operations
        self._cancel_indexing = False
        self._cancel_lock = threading.Lock()

    @property
    def file_lock(self):
        """Get or create the file lock for this database.

        Uses filelock to ensure single-process access to ChromaDB,
        preventing database corruption from concurrent writes.
        """
        if self._file_lock is None:
            filelock = get_filelock()
            # Timeout of 30 seconds - if the lock can't be acquired, raise an error
            self._file_lock = filelock.FileLock(str(self._lock_path), timeout=30)
        return self._file_lock

    @property
    def client(self):
        if self._client is None:
            chromadb = get_chromadb()

            # Check for a stale lock before attempting acquisition.
            # Prevents a 30s timeout from dead processes causing MCP "Connection closed" errors.
            if self._lock_path.exists():
                import time

                lock_age = time.time() - self._lock_path.stat().st_mtime
                # A lock older than 60 seconds is likely from a crashed process
                # (reduced from 300s to catch recently crashed processes)
                if lock_age > 60:
                    logger.warning(
                        f"Removing stale ChromaDB lock (age: {lock_age:.0f}s, path: {self._lock_path})"
                    )
                    try:
                        self._lock_path.unlink(missing_ok=True)
                    except Exception as e:
                        logger.warning(f"Could not remove stale lock: {e}")

            # Acquire lock before creating client to prevent concurrent access
            try:
                with self.file_lock:  # Auto-releases on exit
                    logger.debug(f"Acquired ChromaDB lock for {self.db_path}")
                    self._client = chromadb.PersistentClient(path=str(self.db_path))
            except Exception as e:
                logger.warning(f"Could not acquire ChromaDB lock: {e}. Proceeding without lock.")
                self._client = chromadb.PersistentClient(path=str(self.db_path))
        return self._client

    @property
    def collection(self):
        if self._collection is None:
            self._collection = self.client.get_or_create_collection(
                name="codebase", metadata={"hnsw:space": "cosine"}
            )
        return self._collection

    async def check_embedding_service(self) -> bool:
        """Check if the embedding provider is available."""
        return await self.provider.check_available()

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding vector for text using the configured provider."""
        return await self.provider.get_embedding(text)

    async def get_embeddings_batch(
        self, texts: list[str], max_concurrent: int = 10
    ) -> list[list[float]]:
        """Get embeddings for multiple texts with parallel execution.

        Uses asyncio.gather with semaphore-based concurrency control to avoid
        overwhelming the embedding service while maximizing throughput.

        Args:
            texts: List of text strings to embed
            max_concurrent: Maximum concurrent embedding requests (default: 10)

        Returns:
            List of embedding vectors in the same order as input texts.
        """
        import asyncio

        if not texts:
            return []

        # Use semaphore to limit concurrent requests
        semaphore = asyncio.Semaphore(max_concurrent)

        async def get_with_semaphore(text: str, index: int) -> tuple[int, list[float]]:
            async with semaphore:
                emb = await self.get_embedding(text)
                return (index, emb)

        # Launch all embedding requests concurrently (respecting the semaphore)
        tasks = [get_with_semaphore(text, i) for i, text in enumerate(texts)]
        results = await asyncio.gather(*tasks)

        # Sort by original index to maintain order
        sorted_results = sorted(results, key=lambda x: x[0])
        return [emb for _, emb in sorted_results]
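
    # Usage sketch for batched embedding (hypothetical; assumes an async
    # context and an available provider):
    #
    #     store = get_store(".", provider="ollama")
    #     vectors = await store.get_embeddings_batch(["chunk one", "chunk two"])
    #     assert len(vectors) == 2  # order matches the input list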

    def _chunk_file(self, file_path: Path) -> list[dict]:
        """Split a file into chunks with metadata.

        Uses AST-aware chunking for Python files to respect function/class
        boundaries. Falls back to line-based chunking for other languages.
        """
        try:
            content = file_path.read_text(encoding="utf-8", errors="ignore")
        except Exception:
            return []

        lines = content.split("\n")
        if len(lines) < 5:  # Skip very small files
            return []

        rel_path = str(file_path.resolve().relative_to(self.project_path.resolve()))
        language = file_path.suffix.lstrip(".")

        # Use AST-aware chunking for Python files
        if language == "py":
            chunks = self._chunk_python_ast(content, rel_path, language)
            if chunks:  # If AST parsing succeeded
                return chunks

        # Fallback: line-based chunking for other languages or if AST parsing fails
        return self._chunk_by_lines(lines, rel_path, language)

    def _chunk_python_ast(self, content: str, rel_path: str, language: str) -> list[dict]:
        """Parse a Python file and create chunks based on function/class boundaries.

        Each function, method, and class becomes its own chunk, preserving
        semantic boundaries for better embedding quality.
        """
        import ast

        try:
            tree = ast.parse(content)
        except SyntaxError:
            return []  # Fall back to line-based chunking

        lines = content.split("\n")
        chunks = []

        def get_docstring(node: ast.AST) -> str:
            """Extract the docstring from a node if present."""
            if (
                isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
                and node.body
            ):
                first = node.body[0]
                if isinstance(first, ast.Expr) and isinstance(first.value, ast.Constant):
                    if isinstance(first.value.value, str):
                        return first.value.value
            return ""

        def get_decorators(
            node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
        ) -> list[str]:
            """Extract decorator names from a node."""
            decorators = []
            for dec in node.decorator_list:
                if isinstance(dec, ast.Name):
                    decorators.append(f"@{dec.id}")
                elif isinstance(dec, ast.Attribute):
                    decorators.append(f"@{ast.unparse(dec)}")
                elif isinstance(dec, ast.Call):
                    if isinstance(dec.func, ast.Name):
                        decorators.append(f"@{dec.func.id}")
                    elif isinstance(dec.func, ast.Attribute):
                        decorators.append(f"@{ast.unparse(dec.func)}")
            return decorators

        def get_base_classes(node: ast.ClassDef) -> list[str]:
            """Extract base class names from a class definition."""
            bases = []
            for base in node.bases:
                if isinstance(base, ast.Name):
                    bases.append(base.id)
                elif isinstance(base, ast.Attribute):
                    bases.append(ast.unparse(base))
                else:
                    bases.append(ast.unparse(base))
            return bases

        def get_return_type(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
            """Extract the return type annotation from a function."""
            if node.returns:
                return ast.unparse(node.returns)
            return ""

        def get_parameters(node: ast.FunctionDef | ast.AsyncFunctionDef) -> list[str]:
            """Extract parameter signatures from a function."""
            params = []
            for arg in node.args.args:
                param = arg.arg
                if arg.annotation:
                    param += f": {ast.unparse(arg.annotation)}"
                params.append(param)
            return params

        def add_chunk(
            node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
            node_type: str,
            name: str,
            parent_class: str | None = None,
        ) -> None:
            """Add a chunk for a function/class node."""
            start_line = node.lineno
            end_line = node.end_lineno or start_line

            # Extract the source code for this node
            chunk_lines = lines[start_line - 1 : end_line]
            chunk_text = "\n".join(chunk_lines)
            content_hash = hashlib.md5(chunk_text.encode("utf-8")).hexdigest()[:12]

            # Skip very small chunks
            if len(chunk_lines) < 3:
                return

            # Build a descriptive header
            docstring = get_docstring(node)
            if parent_class:
                header = f"File: {rel_path}\n{node_type}: {parent_class}.{name}\nLines: {start_line}-{end_line}"
            else:
                header = f"File: {rel_path}\n{node_type}: {name}\nLines: {start_line}-{end_line}"

            if docstring:
                header += f"\nDocstring: {docstring[:200]}..."

            document = f"{header}\n\n{chunk_text}"

            chunks.append(
                {
                    "id": f"{rel_path}:{start_line}-{end_line}:{content_hash}",
                    "document": document,
                    "metadata": {
                        "file_path": rel_path,
                        "start_line": start_line,
                        "end_line": end_line,
                        "language": language,
                        "node_type": node_type.lower(),
                        "name": f"{parent_class}.{name}" if parent_class else name,
                        # Structural metadata for filtering
                        "decorators": ",".join(get_decorators(node)),
                        "is_async": isinstance(node, ast.AsyncFunctionDef),
                        # Class-specific metadata
                        "base_classes": ",".join(get_base_classes(node))
                        if isinstance(node, ast.ClassDef)
                        else "",
                        # Function-specific metadata
                        "return_type": get_return_type(node)
                        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                        else "",
                        "parameters": ",".join(get_parameters(node))
                        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                        else "",
                    },
                }
            )

        # Walk the AST and extract functions/classes
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef):
                add_chunk(node, "Class", node.name)
                # Also add methods as separate chunks for granular search
                for item in node.body:
                    if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        add_chunk(item, "Method", item.name, parent_class=node.name)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # Only top-level functions (not methods):
                # check if this function is inside a class body
                is_method = False
                for parent in ast.walk(tree):
                    if isinstance(parent, ast.ClassDef):
                        body = getattr(parent, "body", None)
                        if isinstance(body, list) and node in body:
                            is_method = True
                            break
                if not is_method:
                    add_chunk(node, "Function", node.name)

        # If we found no functions/classes, chunk module-level code
        if not chunks and len(lines) >= 5:
            # Add a module-level chunk for imports and constants
            module_chunk = "\n".join(lines[: min(50, len(lines))])
            chunks.append(
                {
                    "id": f"{rel_path}:1-{min(50, len(lines))}",
                    "document": f"File: {rel_path}\nModule-level code\nLines: 1-{min(50, len(lines))}\n\n{module_chunk}",
                    "metadata": {
                        "file_path": rel_path,
                        "start_line": 1,
                        "end_line": min(50, len(lines)),
                        "language": language,
                        "node_type": "module",
                        "name": rel_path,
                    },
                }
            )

        return chunks

    def _chunk_by_lines(self, lines: list[str], rel_path: str, language: str) -> list[dict]:
        """Fallback line-based chunking with overlap."""
        chunks = []

        for i in range(0, len(lines), self.CHUNK_SIZE - self.CHUNK_OVERLAP):
            chunk_lines = lines[i : i + self.CHUNK_SIZE]
            if len(chunk_lines) < 5:  # Skip tiny trailing chunks
                continue

            chunk_text = "\n".join(chunk_lines)
            content_hash = hashlib.md5(chunk_text.encode("utf-8")).hexdigest()[:12]
            start_line = i + 1
            end_line = i + len(chunk_lines)

            # Create a searchable document with context
            document = f"File: {rel_path}\nLines: {start_line}-{end_line}\n\n{chunk_text}"

            chunks.append(
                {
                    "id": f"{rel_path}:{start_line}-{end_line}:{content_hash}",
                    "document": document,
                    "metadata": {
                        "file_path": rel_path,
                        "start_line": start_line,
                        "end_line": end_line,
                        "language": language,
                    },
                }
            )

        return chunks
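
    # Worked example of the overlap arithmetic above (CHUNK_SIZE=50,
    # CHUNK_OVERLAP=10, so the stride is 40): a 120-line file yields chunks
    # covering lines 1-50, 41-90, and 81-120; each consecutive pair shares
    # 10 lines of context.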

    def _load_whitelist(self) -> set[Path] | None:
        """Load the whitelist from a .stravinskyadd file if present.

        File format:
        - One path per line (relative to project root)
        - Lines starting with # are comments
        - Empty lines are ignored
        - Glob patterns are supported (e.g., src/**/*.py)
        - Directories implicitly include all files within (src/ includes src/**/*.*)

        Returns:
            Set of resolved file paths to include, or None if no whitelist file exists.
        """
        whitelist_file = self.project_path / ".stravinskyadd"
        if not whitelist_file.exists():
            return None

        whitelist_paths: set[Path] = set()
        try:
            content = whitelist_file.read_text(encoding="utf-8")
            for line in content.splitlines():
                line = line.strip()
                # Skip empty lines and comments
                if not line or line.startswith("#"):
                    continue

                # Handle glob patterns
                if "*" in line or "?" in line:
                    for matched_path in self.project_path.glob(line):
                        if (
                            matched_path.is_file()
                            and matched_path.suffix.lower() in self.CODE_EXTENSIONS
                        ):
                            whitelist_paths.add(matched_path.resolve())
                else:
                    target = self.project_path / line
                    if target.exists():
                        if target.is_file():
                            # Direct file reference
                            if target.suffix.lower() in self.CODE_EXTENSIONS:
                                whitelist_paths.add(target.resolve())
                        elif target.is_dir():
                            # Directory: include all code files recursively
                            for file_path in target.rglob("*"):
                                if (
                                    file_path.is_file()
                                    and file_path.suffix.lower() in self.CODE_EXTENSIONS
                                ):
                                    # Apply SKIP_DIRS even within whitelisted directories
                                    if not any(
                                        skip_dir in file_path.parts for skip_dir in self.SKIP_DIRS
                                    ):
                                        whitelist_paths.add(file_path.resolve())

            logger.info(f"Loaded whitelist from .stravinskyadd: {len(whitelist_paths)} files")
            return whitelist_paths

        except Exception as e:
            logger.warning(f"Failed to parse .stravinskyadd: {e}")
            return None

    def _get_files_to_index(self) -> list[Path]:
        """Get all indexable files in the project.

        If a .stravinskyadd whitelist file exists, ONLY those paths are indexed.
        Otherwise, all code files are indexed (excluding SKIP_DIRS).
        """
        # Check for whitelist mode
        whitelist = self._load_whitelist()
        if whitelist is not None:
            logger.info(f"Whitelist mode: indexing {len(whitelist)} files from .stravinskyadd")
            return sorted(whitelist)  # Return sorted for deterministic order

        # Standard mode: crawl the entire project
        files = []
        for file_path in self.project_path.rglob("*"):
            if file_path.is_file():
                # Skip files outside project boundaries (symlink traversal protection)
                try:
                    resolved_file = file_path.resolve()
                    resolved_project = self.project_path.resolve()

                    # Check if the file is under the project using the parent chain with samefile().
                    # This handles macOS /var → /private/var aliasing and symlinks.
                    found = False
                    current = resolved_file.parent
                    while current != current.parent:  # Stop at filesystem root
                        try:
                            if current.samefile(resolved_project):
                                found = True
                                break
                        except OSError:
                            # samefile can fail on some filesystems; try direct comparison
                            if current == resolved_project:
                                found = True
                                break
                        current = current.parent

                    if not found:
                        continue  # Outside project
                except (ValueError, OSError):
                    continue  # Outside project boundaries

                # Skip hidden files and directories
                if any(
                    part.startswith(".") for part in file_path.parts[len(self.project_path.parts) :]
                ) and file_path.suffix not in {".md", ".txt"}:  # Allow .github docs
                    continue

                # Skip excluded directories
                if any(skip_dir in file_path.parts for skip_dir in self.SKIP_DIRS):
                    continue

                # Only include code files
                if file_path.suffix.lower() in self.CODE_EXTENSIONS:
                    files.append(file_path)

        return files

    def request_cancel_indexing(self) -> None:
        """Request cancellation of an ongoing indexing operation.

        Sets a flag that will be checked between batches. The operation will
        stop gracefully after completing the current batch.
        """
        with self._cancel_lock:
            self._cancel_indexing = True
            logger.info(f"Cancellation requested for {self.project_path}")

    def clear_cancel_flag(self) -> None:
        """Clear the cancellation flag."""
        with self._cancel_lock:
            self._cancel_indexing = False

    def is_cancellation_requested(self) -> bool:
        """Check if cancellation has been requested."""
        with self._cancel_lock:
            return self._cancel_indexing
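
    # Cancellation sketch (hypothetical; assumes a running event loop where
    # index_codebase executes as a task and something else requests a stop):
    #
    #     store = get_store("/path/to/project")
    #     task = asyncio.create_task(store.index_codebase())
    #     ...
    #     store.request_cancel_indexing()  # current batch finishes, then stops
    #     stats = await task               # stats["cancelled"] is True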

    async def index_codebase(self, force: bool = False) -> dict:
        """
        Index the entire codebase into the vector store.

        This operation can be cancelled by calling request_cancel_indexing().
        Cancellation happens between batches, so the current batch will complete.

        Args:
            force: If True, reindex everything. Otherwise, only index new/changed files.

        Returns:
            Statistics about the indexing operation.
        """
        import time

        # Clear any previous cancellation requests
        self.clear_cancel_flag()

        # Start timing
        start_time = time.time()

        print(f"🔍 SEMANTIC-INDEX: {self.project_path}", file=sys.stderr)

        # Notify reindex start (non-blocking)
        notifier = None  # Initialize to avoid NameError in error handlers
        try:
            from mcp_bridge.notifications import get_notification_manager

            notifier = get_notification_manager()
            await notifier.notify_reindex_start(str(self.project_path))
        except Exception as e:
            logger.warning(f"Failed to send reindex start notification: {e}")

        try:
            if not await self.check_embedding_service():
                error_msg = "Embedding service not available"
                # Notify error
                try:
                    if notifier:
                        await notifier.notify_reindex_error(error_msg)
                except Exception as e:
                    logger.warning(f"Failed to send reindex error notification: {e}")
                return {"error": error_msg, "indexed": 0}

            # Get existing document IDs
            existing_ids = set()
            try:
                # Only fetch IDs to minimize overhead
                existing = self.collection.get(include=[])
                existing_ids = set(existing["ids"]) if existing["ids"] else set()
            except Exception:
                pass

            if force:
                # Clear the existing collection
                try:
                    self.client.delete_collection("codebase")
                    self._collection = None
                    existing_ids = set()
                except Exception:
                    pass

            files = self._get_files_to_index()
            all_chunks = []
            current_chunk_ids = set()

            # Mark: generate all chunks for the current codebase
            for file_path in files:
                chunks = self._chunk_file(file_path)
                all_chunks.extend(chunks)
                for c in chunks:
                    current_chunk_ids.add(c["id"])

            # Sweep: identify stale chunks to remove
            to_delete = existing_ids - current_chunk_ids

            # Identify new chunks to add
            to_add_ids = current_chunk_ids - existing_ids
            chunks_to_add = [c for c in all_chunks if c["id"] in to_add_ids]

            # Prune stale chunks
            if to_delete:
                print(f"  Pruning {len(to_delete)} stale chunks...", file=sys.stderr)
                self.collection.delete(ids=list(to_delete))

            if not chunks_to_add:
                stats = {
                    "indexed": 0,
                    "pruned": len(to_delete),
                    "total_files": len(files),
                    "message": "No new chunks to index",
                    "time_taken": round(time.time() - start_time, 1),
                }
                # Notify completion
                try:
                    if notifier:
                        await notifier.notify_reindex_complete(stats)
                except Exception as e:
                    logger.warning(f"Failed to send reindex complete notification: {e}")
                return stats

            # Batch embed and store
            batch_size = 50
            total_indexed = 0

            for i in range(0, len(chunks_to_add), batch_size):
                # Check for cancellation between batches
                if self.is_cancellation_requested():
                    print(f"  ⚠️ Indexing cancelled after {total_indexed} chunks", file=sys.stderr)
                    stats = {
                        "indexed": total_indexed,
                        "pruned": len(to_delete),
                        "total_files": len(files),
                        "db_path": str(self.db_path),
                        "time_taken": round(time.time() - start_time, 1),
                        "cancelled": True,
                        "message": f"Cancelled after {total_indexed}/{len(chunks_to_add)} chunks",
                    }
                    # Notify cancellation
                    try:
                        if notifier:
                            await notifier.notify_reindex_error(
                                f"Indexing cancelled by user after {total_indexed} chunks"
                            )
                    except Exception as e:
                        logger.warning(f"Failed to send cancellation notification: {e}")
                    return stats

                batch = chunks_to_add[i : i + batch_size]

                documents = [c["document"] for c in batch]
                embeddings = await self.get_embeddings_batch(documents)

                self.collection.add(
                    ids=[c["id"] for c in batch],
                    documents=documents,
                    embeddings=embeddings,  # type: ignore[arg-type]
                    metadatas=[c["metadata"] for c in batch],
                )
                total_indexed += len(batch)
                print(f"  Indexed {total_indexed}/{len(chunks_to_add)} chunks...", file=sys.stderr)

            stats = {
                "indexed": total_indexed,
                "pruned": len(to_delete),
                "total_files": len(files),
                "db_path": str(self.db_path),
                "time_taken": round(time.time() - start_time, 1),
            }

            # Notify completion
            try:
                if notifier:
                    await notifier.notify_reindex_complete(stats)
            except Exception as e:
                logger.warning(f"Failed to send reindex complete notification: {e}")

            return stats

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Reindexing failed: {error_msg}")

            # Notify error
            try:
                if notifier:
                    await notifier.notify_reindex_error(error_msg)
            except Exception as notify_error:
                logger.warning(f"Failed to send reindex error notification: {notify_error}")

            raise

    async def search(
        self,
        query: str,
        n_results: int = 10,
        language: str | None = None,
        node_type: str | None = None,
        decorator: str | None = None,
        is_async: bool | None = None,
        base_class: str | None = None,
    ) -> list[dict]:
        """
        Search the codebase with a natural language query.

        Args:
            query: Natural language search query
            n_results: Maximum number of results to return
            language: Filter by language (e.g., "py", "ts", "js")
            node_type: Filter by node type (e.g., "function", "class", "method")
            decorator: Filter by decorator (e.g., "@property", "@staticmethod")
            is_async: Filter by async status (True = async only, False = sync only)
            base_class: Filter by base class (e.g., "BaseClass")

        Returns:
            List of matching code chunks with metadata.
        """
        filters = []
        if language:
            filters.append(f"language={language}")
        if node_type:
            filters.append(f"node_type={node_type}")
        if decorator:
            filters.append(f"decorator={decorator}")
        if is_async is not None:
            filters.append(f"is_async={is_async}")
        if base_class:
            filters.append(f"base_class={base_class}")
        filter_str = f" [{', '.join(filters)}]" if filters else ""
        print(f"🔎 SEMANTIC-SEARCH: '{query[:50]}...'{filter_str}", file=sys.stderr)

        if not await self.check_embedding_service():
            return [{"error": "Embedding service not available"}]

        # Check if the collection has documents
        try:
            count = self.collection.count()
            if count == 0:
                return [{"error": "No documents indexed", "hint": "Run index_codebase first"}]
        except Exception as e:
            return [{"error": f"Collection error: {e}"}]

        # Get the query embedding
        query_embedding = await self.get_embedding(query)

        # Build the where clause for metadata filtering
        where_filters = []
        if language:
            where_filters.append({"language": language})
        if node_type:
            where_filters.append({"node_type": node_type.lower()})
        if decorator:
            # ChromaDB $like for substring match in a comma-separated field.
            # Use % wildcards for pattern matching.
            where_filters.append({"decorators": {"$like": f"%{decorator}%"}})
        if is_async is not None:
            where_filters.append({"is_async": is_async})
        if base_class:
            # Use $like for substring match
            where_filters.append({"base_classes": {"$like": f"%{base_class}%"}})

        where_clause = None
        if len(where_filters) == 1:
            where_clause = where_filters[0]
        elif len(where_filters) > 1:
            where_clause = {"$and": where_filters}

        # Search with optional filtering
        query_kwargs: dict = {
            "query_embeddings": [query_embedding],
            "n_results": n_results,
            "include": ["documents", "metadatas", "distances"],
        }
        if where_clause:
            query_kwargs["where"] = where_clause

        results = self.collection.query(**query_kwargs)

        # Format results
        formatted = []
        if results["ids"] and results["ids"][0]:
            for i, _doc_id in enumerate(results["ids"][0]):
                metadata = results["metadatas"][0][i] if results["metadatas"] else {}
                distance = results["distances"][0][i] if results["distances"] else 0
                document = results["documents"][0][i] if results["documents"] else ""

                # Extract just the code part (skip the file/line header)
                code_lines = document.split("\n\n", 1)
                code = code_lines[1] if len(code_lines) > 1 else document

                formatted.append(
                    {
                        "file": metadata.get("file_path", "unknown"),
                        "lines": f"{metadata.get('start_line', '?')}-{metadata.get('end_line', '?')}",
                        "language": metadata.get("language", ""),
                        "relevance": round(1 - distance, 3),  # Convert distance to similarity
                        "code_preview": code[:500] + "..." if len(code) > 500 else code,
                    }
                )

        return formatted
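
    # Filter sketch: with language="py" and is_async=True, the where clause
    # built above becomes (ChromaDB operator syntax):
    #
    #     {"$and": [{"language": "py"}, {"is_async": True}]}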

    def get_stats(self) -> dict:
        """Get statistics about the vector store."""
        try:
            count = self.collection.count()
            return {
                "project_path": str(self.project_path),
                "db_path": str(self.db_path),
                "chunks_indexed": count,
                "embedding_provider": self.provider.name,
                "embedding_dimension": self.provider.dimension,
            }
        except Exception as e:
            return {"error": str(e)}

    def start_watching(self, debounce_seconds: float = 2.0) -> "CodebaseFileWatcher":
        """Start watching the project directory for file changes.

        Args:
            debounce_seconds: Time to wait before reindexing after changes (default: 2.0s)

        Returns:
            The CodebaseFileWatcher instance
        """
        with self._watcher_lock:
            if self._watcher is None:
                # Avoid a circular import by importing here
                self._watcher = CodebaseFileWatcher(
                    project_path=self.project_path,
                    store=self,
                    debounce_seconds=debounce_seconds,
                )
                self._watcher.start()
            else:
                if not self._watcher.is_running():
                    self._watcher.start()
                else:
                    logger.warning(f"Watcher for {self.project_path} is already running")
            return self._watcher

    def stop_watching(self) -> bool:
        """Stop watching the project directory.

        Returns:
            True if the watcher was stopped, False if no watcher was active
        """
        with self._watcher_lock:
            if self._watcher is not None:
                self._watcher.stop()
                self._watcher = None
                return True
            return False

    def is_watching(self) -> bool:
        """Check if the project directory is being watched.

        Returns:
            True if the watcher is active and running, False otherwise
        """
        with self._watcher_lock:
            if self._watcher is not None:
                return self._watcher.is_running()
            return False
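
    # Watcher sketch (hypothetical; CodebaseFileWatcher is defined later in
    # this module and reindexes after a quiet period):
    #
    #     store = get_store("/path/to/project")
    #     store.start_watching(debounce_seconds=2.0)
    #     # ...edit files; the index refreshes ~2s after the last change...
    #     store.stop_watching()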
1865# --- Module-level API for MCP tools ---
1867_stores: dict[str, CodebaseVectorStore] = {}
1868_stores_lock = threading.Lock()
1870# Module-level watcher management
1871_watchers: dict[str, "CodebaseFileWatcher"] = {}
1872_watchers_lock = threading.Lock()
1875def _cleanup_watchers():
1876 """Cleanup function to stop all watchers on exit.
1878 Registered with atexit to ensure graceful shutdown when Python exits normally.
1879 Note: This won't be called if the process is killed (SIGKILL) or crashes.
1880 """
1881 with _watchers_lock:
1882 for path, watcher in list(_watchers.items()):
1883 try:
1884 logger.debug(f"Stopping watcher for {path} on exit")
1885 watcher.stop()
1886 except Exception as e:
1887 logger.warning(f"Error stopping watcher for {path}: {e}")
1890# Register cleanup handler for graceful shutdown
1891atexit.register(_cleanup_watchers)
1894def get_store(project_path: str, provider: EmbeddingProvider = "ollama") -> CodebaseVectorStore:
1895 """Get or create a vector store for a project.
1897 Note: Cache key includes provider to prevent cross-provider conflicts
1898 (different providers have different embedding dimensions).
1899 """
1900 path = str(Path(project_path).resolve())
1901 cache_key = f"{path}:{provider}"
1902 if cache_key not in _stores:
1903 with _stores_lock:
1904 # Double-check pattern to avoid race condition
1905 if cache_key not in _stores:
1906 _stores[cache_key] = CodebaseVectorStore(path, provider)
1907 return _stores[cache_key]
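# Illustrative only: a minimal sketch of the provider-scoped cache (hypothetical path;
# assumes ChromaDB is importable and ~/.stravinsky is writable, since constructing a
# store touches the on-disk database).
def _demo_provider_scoped_cache() -> None:
    store_a = get_store("/tmp/demo_project", provider="ollama")
    store_b = get_store("/tmp/demo_project", provider="mxbai")
    assert store_a is not store_b  # cache keys "<path>:ollama" vs "<path>:mxbai"
    assert store_a is get_store("/tmp/demo_project", provider="ollama")  # instance reused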
1910async def semantic_search(
1911 query: str,
1912 project_path: str = ".",
1913 n_results: int = 10,
1914 language: str | None = None,
1915 node_type: str | None = None,
1916 decorator: str | None = None,
1917 is_async: bool | None = None,
1918 base_class: str | None = None,
1919 provider: EmbeddingProvider = "ollama",
1920) -> str:
1921 """
1922 Search codebase with natural language query.
1924 Args:
1925 query: Natural language search query (e.g., "find authentication logic")
1926 project_path: Path to the project root
1927 n_results: Maximum number of results to return
1928 language: Filter by language (e.g., "py", "ts", "js")
1929 node_type: Filter by node type (e.g., "function", "class", "method")
1930 decorator: Filter by decorator (e.g., "@property", "@staticmethod")
1931 is_async: Filter by async status (True = async only, False = sync only)
1932 base_class: Filter by base class (e.g., "BaseClass")
1933 provider: Embedding provider (ollama, mxbai, gemini, openai, huggingface)
1935 Returns:
1936 Formatted search results with file paths and code snippets.
1937 """
1938 store = get_store(project_path, provider)
1939 results = await store.search(
1940 query,
1941 n_results,
1942 language,
1943 node_type,
1944 decorator=decorator,
1945 is_async=is_async,
1946 base_class=base_class,
1947 )
1949 if not results:
1950 return "No results found"
1952 if "error" in results[0]:
1953 return f"Error: {results[0]['error']}\nHint: {results[0].get('hint', 'Check Ollama is running')}"
1955 lines = [f"Found {len(results)} results for: '{query}'\n"]
1956 for i, r in enumerate(results, 1):
1957 lines.append(f"{i}. {r['file']}:{r['lines']} (relevance: {r['relevance']})")
1958 lines.append(f"```{r['language']}")
1959 lines.append(r["code_preview"])
1960 lines.append("```\n")
1962 return "\n".join(lines)
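# Illustrative only: calling the tool-level API directly (hypothetical path; assumes
# index_codebase() has already been run for this project/provider pair and that the
# embedding service is reachable). Run with: asyncio.run(_demo_semantic_search())
async def _demo_semantic_search() -> None:
    text = await semantic_search(
        "find authentication logic",
        project_path="/tmp/demo_project",
        n_results=5,
        language="py",
        is_async=True,  # restrict to async defs
    )
    print(text)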
1965async def hybrid_search(
1966 query: str,
1967 pattern: str | None = None,
1968 project_path: str = ".",
1969 n_results: int = 10,
1970 language: str | None = None,
1971 node_type: str | None = None,
1972 decorator: str | None = None,
1973 is_async: bool | None = None,
1974 base_class: str | None = None,
1975 provider: EmbeddingProvider = "ollama",
1976) -> str:
1977 """
1978 Hybrid search combining semantic similarity with structural AST matching.
1980 Performs semantic search first, then optionally filters/boosts results
1981 that also match an ast-grep structural pattern.
1983 Args:
1984 query: Natural language search query (e.g., "find authentication logic")
1985 pattern: Optional ast-grep pattern for structural matching (e.g., "def $FUNC($$$):")
1986 project_path: Path to the project root
1987 n_results: Maximum number of results to return
1988 language: Filter by language (e.g., "py", "ts", "js")
1989 node_type: Filter by node type (e.g., "function", "class", "method")
1990 decorator: Filter by decorator (e.g., "@property", "@staticmethod")
1991 is_async: Filter by async status (True = async only, False = sync only)
1992 base_class: Filter by base class (e.g., "BaseClass")
1993 provider: Embedding provider (ollama, mxbai, gemini, openai, huggingface)
1995 Returns:
1996 Formatted search results with relevance scores and structural match indicators.
1997 """
1998 from mcp_bridge.tools.code_search import ast_grep_search
2000 # Get semantic results (fetch more if we're going to filter)
2001 fetch_count = n_results * 2 if pattern else n_results
2002 semantic_result = await semantic_search(
2003 query=query,
2004 project_path=project_path,
2005 n_results=fetch_count,
2006 language=language,
2007 node_type=node_type,
2008 decorator=decorator,
2009 is_async=is_async,
2010 base_class=base_class,
2011 provider=provider,
2012 )
2014 if not pattern:
2015 return semantic_result
2017 if semantic_result.startswith("Error:") or semantic_result == "No results found":
2018 return semantic_result
2020 # Get structural matches from ast-grep
2021 ast_result = await ast_grep_search(
2022 pattern=pattern,
2023 directory=project_path,
2024 language=language or "",
2025 )
2027 # Extract file paths from ast-grep results
2028 ast_files: set[str] = set()
2029 if ast_result and not ast_result.startswith("Error:") and ast_result != "No matches found":
2030 for line in ast_result.split("\n"):
2031 if line.startswith("- "):
2032 # Format: "- file.py:123"
2033 file_part = line[2:].split(":")[0]
2034 ast_files.add(file_part)
2036 if not ast_files:
2037 # No structural matches, return semantic results with note
2038 return f"{semantic_result}\n\n[Note: No structural matches for pattern '{pattern}']"
2040 # Parse semantic results and boost/annotate files that appear in both
2041 lines = []
2042 result_lines = semantic_result.split("\n")
2043 header = result_lines[0] if result_lines else ""
2044 lines.append(header.replace("results for:", "hybrid results for:"))
2045 lines.append(f"[Structural pattern: {pattern}]\n")
2047 boosted_count = 0
2048 for line in result_lines[1:]:
2049 # Annotate result header lines (e.g., "1. file.py:10-20") that also match structurally
2050 if line and line[0].isdigit() and "." in line:
2051 parts = line.split()
2052 file_part = parts[1].split(":")[0] if len(parts) > 1 else ""
2053 if file_part in ast_files:
2054 lines.append(f"{line} 🎯 [structural match]")
2055 boosted_count += 1
2056 continue
2057 lines.append(line)
2066 lines.append(
2067 f"\n[{boosted_count} semantic result(s) also match the structural pattern]"
2068 )
2070 return "\n".join(lines)
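# Worked example of the annotation pass above, isolated for clarity (made-up values):
def _demo_annotate_header(line: str, ast_files: set[str]) -> str:
    parts = line.split()
    file_part = parts[1].split(":")[0] if len(parts) > 1 else ""
    return f"{line} 🎯 [structural match]" if file_part in ast_files else line

# _demo_annotate_header("1. auth.py:10-42 (relevance: 0.87)", {"auth.py"})
# -> "1. auth.py:10-42 (relevance: 0.87) 🎯 [structural match]"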
2073async def index_codebase(
2074 project_path: str = ".",
2075 force: bool = False,
2076 provider: EmbeddingProvider = "ollama",
2077) -> str:
2078 """
2079 Index a codebase for semantic search.
2081 Args:
2082 project_path: Path to the project root
2083 force: If True, reindex everything. Otherwise, only new/changed files.
2084 provider: Embedding provider - ollama (local/free), mxbai (local/free),
2085 gemini (cloud/OAuth), openai (cloud/OAuth), huggingface (cloud/token)
2087 Returns:
2088 Indexing statistics.
2089 """
2090 store = get_store(project_path, provider)
2091 stats = await store.index_codebase(force=force)
2093 if "error" in stats:
2094 return f"Error: {stats['error']}"
2096 if stats.get("cancelled"):
2097 return (
2098 f"⚠️ Indexing cancelled\n"
2099 f"Indexed {stats['indexed']} chunks from {stats['total_files']} files before cancellation\n"
2100 f"{stats.get('message', '')}"
2101 )
2103 return (
2104 f"Indexed {stats['indexed']} chunks from {stats['total_files']} files\n"
2105 f"Database: {stats.get('db_path', 'unknown')}\n"
2106 f"{stats.get('message', '')}"
2107 )
2110def cancel_indexing(
2111 project_path: str = ".",
2112 provider: EmbeddingProvider = "ollama",
2113) -> str:
2114 """
2115 Cancel an ongoing indexing operation.
2117 The cancellation happens gracefully between batches - the current batch
2118 will complete before the operation stops.
2120 Args:
2121 project_path: Path to the project root
2122 provider: Embedding provider (must match the one used for indexing)
2124 Returns:
2125 Confirmation message.
2126 """
2127 try:
2128 store = get_store(project_path, provider)
2129 store.request_cancel_indexing()
2130 return f"✅ Cancellation requested for {project_path}\nIndexing will stop after current batch completes."
2131 except Exception as e:
2132 return f"❌ Error requesting cancellation: {e}"
2135async def semantic_stats(
2136 project_path: str = ".",
2137 provider: EmbeddingProvider = "ollama",
2138) -> str:
2139 """
2140 Get statistics about the semantic search index.
2142 Args:
2143 project_path: Path to the project root
2144 provider: Embedding provider - ollama (local/free), mxbai (local/free),
2145 gemini (cloud/OAuth), openai (cloud/OAuth), huggingface (cloud/token)
2147 Returns:
2148 Index statistics.
2149 """
2150 store = get_store(project_path, provider)
2151 stats = store.get_stats()
2153 if "error" in stats:
2154 return f"Error: {stats['error']}"
2156 return (
2157 f"Project: {stats['project_path']}\n"
2158 f"Database: {stats['db_path']}\n"
2159 f"Chunks indexed: {stats['chunks_indexed']}\n"
2160 f"Embedding provider: {stats['embedding_provider']} ({stats['embedding_dimension']} dims)"
2161 )
2164def delete_index(
2165 project_path: str = ".",
2166 provider: EmbeddingProvider | None = None,
2167 delete_all: bool = False,
2168) -> str:
2169 """
2170 Delete semantic search index for a project.
2172 Args:
2173 project_path: Path to the project root
2174 provider: Embedding provider (if None and delete_all=False, deletes all providers for this project)
2175 delete_all: If True, delete ALL indexes for ALL projects (ignores project_path and provider)
2177 Returns:
2178 Confirmation message with deleted paths.
2179 """
2180 import shutil
2182 vectordb_base = Path.home() / ".stravinsky" / "vectordb"
2184 if not vectordb_base.exists():
2185 return "✅ No semantic search indexes found (vectordb directory doesn't exist)"
2187 if delete_all:
2188 # Delete entire vectordb directory
2189 try:
2190 shutil.rmtree(vectordb_base)
2191 return "✅ Deleted all semantic search indexes for all projects"
2192 except Exception as e:
2193 return f"❌ Error deleting all indexes: {e}"
2195 # Generate project hash
2196 project_path_resolved = Path(project_path).resolve()
2197 project_hash = hashlib.md5(str(project_path_resolved).encode()).hexdigest()[:12]
2199 deleted = []
2200 errors = []
2202 if provider:
2203 # Delete specific provider index for this project
2204 index_path = vectordb_base / f"{project_hash}_{provider}"
2205 if index_path.exists():
2206 try:
2207 shutil.rmtree(index_path)
2208 deleted.append(str(index_path))
2209 except Exception as e:
2210 errors.append(f"{provider}: {e}")
2211 else:
2212 errors.append(f"{provider}: Index not found")
2213 else:
2214 # Delete all provider indexes for this project
2215 providers: list[EmbeddingProvider] = ["ollama", "mxbai", "gemini", "openai", "huggingface"]
2216 for prov in providers:
2217 index_path = vectordb_base / f"{project_hash}_{prov}"
2218 if index_path.exists():
2219 try:
2220 shutil.rmtree(index_path)
2221 deleted.append(str(index_path))
2222 except Exception as e:
2223 errors.append(f"{prov}: {e}")
2225 if not deleted and not errors:
2226 return f"⚠️ No indexes found for project: {project_path_resolved}\nProject hash: {project_hash}"
2228 result = []
2229 if deleted:
2230 result.append(f"✅ Deleted {len(deleted)} index(es):")
2231 for path in deleted:
2232 result.append(f" - {path}")
2233 if errors:
2234 result.append(f"\n❌ Errors ({len(errors)}):")
2235 for error in errors:
2236 result.append(f" - {error}")
2238 return "\n".join(result)
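# Illustrative only: how per-provider index directories are derived from a project path
# (hypothetical path; mirrors the hashing used above).
def _demo_index_dir(project_path: str, provider: str) -> Path:
    resolved = Path(project_path).resolve()
    project_hash = hashlib.md5(str(resolved).encode()).hexdigest()[:12]
    return Path.home() / ".stravinsky" / "vectordb" / f"{project_hash}_{provider}"

# _demo_index_dir("/tmp/demo_project", "ollama")
# -> ~/.stravinsky/vectordb/<12-char-md5>_ollama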
2241async def semantic_health(project_path: str = ".", provider: EmbeddingProvider = "ollama") -> str:
2242 """Check health of semantic search system."""
2243 store = get_store(project_path, provider)
2245 status = []
2247 # Check Provider
2248 try:
2249 is_avail = await store.check_embedding_service()
2250 status.append(
2251 f"Provider ({store.provider.name}): {'✅ Online' if is_avail else '❌ Offline'}"
2252 )
2253 except Exception as e:
2254 status.append(f"Provider ({store.provider.name}): ❌ Error - {e}")
2256 # Check DB
2257 try:
2258 count = store.collection.count()
2259 status.append(f"Vector DB: ✅ Online ({count} documents)")
2260 except Exception as e:
2261 status.append(f"Vector DB: ❌ Error - {e}")
2263 return "\n".join(status)
2266# ========================
2267# FILE WATCHER MANAGEMENT
2268# ========================
2271async def start_file_watcher(
2272 project_path: str,
2273 provider: EmbeddingProvider = "ollama",
2274 debounce_seconds: float = 2.0,
2275) -> "CodebaseFileWatcher":
2276 """Start watching a project directory for file changes.
2278 If an index exists, automatically performs an incremental reindex to catch up
2279 on any changes that happened while the watcher was not running.
2281 Args:
2282 project_path: Path to the project root
2283 provider: Embedding provider to use for reindexing
2284 debounce_seconds: Time to wait before reindexing after changes
2286 Returns:
2287 The started CodebaseFileWatcher instance
2288 """
2289 path = str(Path(project_path).resolve())
2290 with _watchers_lock:
2291 if path not in _watchers:
2292 store = get_store(project_path, provider)
2294 # Check if index exists - CRITICAL: Must have index before watching
2295 try:
2296 stats = store.get_stats()
2297 chunks_indexed = stats.get("chunks_indexed", 0)
2298 if chunks_indexed == 0:
2299 raise ValueError(
2300 f"No semantic index found for '{path}'. "
2301 f"Run semantic_index(project_path='{path}', provider='{provider}') "
2302 f"before starting the file watcher."
2303 )
2305 # Index exists - catch up on any missed changes
2306 print("📋 Catching up on changes since last index...", file=sys.stderr)
2307 await store.index_codebase(force=False)
2308 print("✅ Index updated, starting file watcher", file=sys.stderr)
2310 except ValueError:
2311 # Re-raise ValueError (our intentional error)
2312 raise
2313 except Exception as e:
2314 # Collection doesn't exist or other error
2315 raise ValueError(
2316 f"No semantic index found for '{path}'. "
2317 f"Run semantic_index(project_path='{path}', provider='{provider}') "
2318 f"before starting the file watcher."
2319 ) from e
2321 watcher = store.start_watching(debounce_seconds=debounce_seconds)
2322 _watchers[path] = watcher
2323 else:
2324 watcher = _watchers[path]
2325 if not watcher.is_running():
2326 watcher.start()
2327 return _watchers[path]
2330def stop_file_watcher(project_path: str) -> bool:
2331 """Stop watching a project directory.
2333 Args:
2334 project_path: Path to the project root
2336 Returns:
2337 True if watcher was stopped, False if no watcher was active
2338 """
2339 path = str(Path(project_path).resolve())
2340 with _watchers_lock:
2341 if path in _watchers:
2342 watcher = _watchers[path]
2343 watcher.stop()
2344 del _watchers[path]
2345 return True
2346 return False
2349def get_file_watcher(project_path: str) -> "CodebaseFileWatcher | None":
2350 """Get an active file watcher for a project.
2352 Args:
2353 project_path: Path to the project root
2355 Returns:
2356 The CodebaseFileWatcher if active, None otherwise
2357 """
2358 path = str(Path(project_path).resolve())
2359 with _watchers_lock:
2360 watcher = _watchers.get(path)
2361 if watcher is not None and watcher.is_running():
2362 return watcher
2363 return None
2366def list_file_watchers() -> list[dict]:
2367 """List all active file watchers.
2369 Returns:
2370 List of dicts with watcher info (project_path, debounce_seconds, provider, status)
2371 """
2372 with _watchers_lock:
2373 watchers_info = []
2374 for path, watcher in _watchers.items():
2375 watchers_info.append(
2376 {
2377 "project_path": path,
2378 "debounce_seconds": watcher.debounce_seconds,
2379 "provider": watcher.store.provider_name,
2380 "status": "running" if watcher.is_running() else "stopped",
2381 }
2382 )
2383 return watchers_info
2386# ========================
2387# MULTI-QUERY EXPANSION & DECOMPOSITION
2388# ========================
2391async def _expand_query_with_llm(query: str, num_variations: int = 3) -> list[str]:
2392 """
2393 Use LLM to rephrase a query into multiple semantic variations.
2395 For example: "database connection" -> ["SQLAlchemy engine setup",
2396 "connect to postgres", "db session management"]
2398 Args:
2399 query: Original search query
2400 num_variations: Number of variations to generate (default: 3)
2402 Returns:
2403 List of query variations including the original
2404 """
2405 from mcp_bridge.tools.model_invoke import invoke_gemini
2407 prompt = f"""You are a code search query expander. Given a search query, generate {num_variations} alternative phrasings that would help find relevant code.
2409Original query: "{query}"
2411Generate {num_variations} alternative queries that:
2412 1. Use different technical terminology (e.g., "database" -> "SQLAlchemy", "ORM", "connection pool")
2413 2. Reference specific implementations or patterns
2414 3. Include related concepts that might appear in code
2416Return ONLY the alternative queries, one per line. No numbering, no explanations.
2417Example output for "database connection":
2418SQLAlchemy engine configuration
2419postgres connection setup
2420db session factory pattern"""
2422 try:
2423 result = await invoke_gemini(
2424 token_store=TokenStore(),
2425 prompt=prompt,
2426 model="gemini-2.0-flash",
2427 temperature=0.7,
2428 max_tokens=200,
2429 )
2431 # Parse variations from response
2432 variations = [line.strip() for line in result.strip().split("\n") if line.strip()]
2433 # Always include original query first
2434 all_queries = [query] + variations[:num_variations]
2435 return all_queries
2437 except Exception as e:
2438 logger.warning(f"Query expansion failed: {e}, using original query only")
2439 return [query]
2442async def _decompose_query_with_llm(query: str) -> list[str]:
2443 """
2444 Break a complex query into smaller, focused sub-questions.
2446 For example: "Initialize the DB and then create a user model" ->
2447 ["database initialization", "user model definition"]
2449 Args:
2450 query: Complex search query
2452 Returns:
2453 List of sub-queries, or [query] if decomposition not needed
2454 """
2455 from mcp_bridge.tools.model_invoke import invoke_gemini
2457 prompt = f"""You are a code search query analyzer. Determine if this query should be broken into sub-queries.
2459Query: "{query}"
2461If the query contains multiple distinct concepts (connected by "and", "then", "also", etc.),
2462break it into separate focused sub-queries.
2464If the query is already focused on a single concept, return just that query.
2466Return ONLY the sub-queries, one per line. No numbering, no explanations.
2468Examples:
2469- "Initialize the DB and then create a user model" ->
2470database initialization
2471user model definition
2473- "authentication logic" ->
2474authentication logic"""
2476 try:
2477 result = await invoke_gemini(
2478 token_store=TokenStore(),
2479 prompt=prompt,
2480 model="gemini-2.0-flash",
2481 temperature=0.3, # Lower temperature for more consistent decomposition
2482 max_tokens=150,
2483 )
2485 # Parse sub-queries from response
2486 sub_queries = [line.strip() for line in result.strip().split("\n") if line.strip()]
2487 return sub_queries if sub_queries else [query]
2489 except Exception as e:
2490 logger.warning(f"Query decomposition failed: {e}, using original query")
2491 return [query]
2494def _aggregate_results(
2495 all_results: list[list[dict]],
2496 n_results: int = 10,
2497) -> list[dict]:
2498 """
2499 Aggregate and deduplicate results from multiple queries.
2501 Uses reciprocal rank fusion to combine relevance scores from different queries.
2503 Args:
2504 all_results: List of result lists from different queries
2505 n_results: Maximum number of results to return
2507 Returns:
2508 Deduplicated and re-ranked results
2509 """
2510 # Track seen files to avoid duplicates
2511 seen_files: dict[str, dict] = {} # file:lines -> result with best score
2512 file_scores: dict[str, float] = {} # file:lines -> aggregated score
2514 # Reciprocal Rank Fusion constant
2515 k = 60
2517 for results in all_results:
2518 for rank, result in enumerate(results):
2519 file_key = f"{result.get('file', '')}:{result.get('lines', '')}"
2521 # RRF score contribution
2522 rrf_score = 1 / (k + rank + 1)
2524 if file_key not in seen_files:
2525 seen_files[file_key] = result.copy()
2526 file_scores[file_key] = rrf_score
2527 else:
2528 # Aggregate scores
2529 file_scores[file_key] += rrf_score
2530 # Keep higher original relevance if available
2531 if result.get("relevance", 0) > seen_files[file_key].get("relevance", 0):
2532 seen_files[file_key] = result.copy()
2534 # Sort by aggregated score and return top N
2535 sorted_keys = sorted(file_scores, key=file_scores.get, reverse=True)
2536 max_score = max(file_scores.values()) if file_scores else 1
2538 aggregated = []
2539 for key in sorted_keys[:n_results]:
2540 result = seen_files[key]
2541 # Normalize relevance against the best aggregated score
2542 result["relevance"] = round(file_scores[key] / max_score, 3)
2543 aggregated.append(result)
2545 return aggregated
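# A worked example of the fusion above (made-up results). With k = 60, a chunk ranked
# first by two queries scores 2/61 ≈ 0.033, beating a chunk ranked second by one query
# (1/62 ≈ 0.016); normalized relevances come out as 1.0 and ≈ 0.492 respectively.
def _demo_rrf() -> list[dict]:
    query_a = [
        {"file": "auth.py", "lines": "1-20", "relevance": 0.90},
        {"file": "db.py", "lines": "5-30", "relevance": 0.80},
    ]
    query_b = [{"file": "auth.py", "lines": "1-20", "relevance": 0.85}]
    return _aggregate_results([query_a, query_b], n_results=2)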
2548async def multi_query_search(
2549 query: str,
2550 project_path: str = ".",
2551 n_results: int = 10,
2552 num_expansions: int = 3,
2553 language: str | None = None,
2554 node_type: str | None = None,
2555 provider: EmbeddingProvider = "ollama",
2556) -> str:
2557 """
2558 Search with LLM-expanded query variations for better recall.
2560 Rephrases the query into multiple semantic variations, searches for each,
2561 and aggregates results using reciprocal rank fusion.
2563 Args:
2564 query: Natural language search query
2565 project_path: Path to the project root
2566 n_results: Maximum number of results to return
2567 num_expansions: Number of query variations to generate (default: 3)
2568 language: Filter by language (e.g., "py", "ts")
2569 node_type: Filter by node type (e.g., "function", "class")
2570 provider: Embedding provider
2572 Returns:
2573 Formatted search results with relevance scores.
2574 """
2575 import asyncio
2577 print(f"🔍 MULTI-QUERY: Expanding '{query[:50]}...'", file=sys.stderr)
2579 # Get query expansions
2580 expanded_queries = await _expand_query_with_llm(query, num_expansions)
2581 print(f" Generated {len(expanded_queries)} query variations", file=sys.stderr)
2583 # Get store once
2584 store = get_store(project_path, provider)
2586 # Search with all queries in parallel
2587 async def search_single(q: str) -> list[dict]:
2588 return await store.search(
2589 q,
2590 n_results=n_results, # Get full results for each query
2591 language=language,
2592 node_type=node_type,
2593 )
2595 all_results = await asyncio.gather(*[search_single(q) for q in expanded_queries])
2597 # Filter out error results
2598 valid_results = [r for r in all_results if r and "error" not in r[0]]
2600 if not valid_results:
2601 if all_results and all_results[0] and "error" in all_results[0][0]:
2602 return f"Error: {all_results[0][0]['error']}"
2603 return "No results found"
2605 # Aggregate results
2606 aggregated = _aggregate_results(valid_results, n_results)
2608 if not aggregated:
2609 return "No results found"
2611 # Format output
2612 lines = [f"Found {len(aggregated)} results for multi-query expansion of: '{query}'"]
2613 lines.append(
2614 f"[Expanded to: {', '.join(q[:30] + '...' if len(q) > 30 else q for q in expanded_queries)}]\n"
2615 )
2617 for i, r in enumerate(aggregated, 1):
2618 lines.append(f"{i}. {r['file']}:{r['lines']} (relevance: {r['relevance']})")
2619 lines.append(f"```{r.get('language', '')}")
2620 lines.append(r.get("code_preview", ""))
2621 lines.append("```\n")
2623 return "\n".join(lines)
2626async def decomposed_search(
2627 query: str,
2628 project_path: str = ".",
2629 n_results: int = 10,
2630 language: str | None = None,
2631 node_type: str | None = None,
2632 provider: EmbeddingProvider = "ollama",
2633) -> str:
2634 """
2635 Search by decomposing complex queries into focused sub-questions.
2637 Breaks multi-part queries like "Initialize the DB and create a user model"
2638 into separate searches, returning organized results for each part.
2640 Args:
2641 query: Complex search query (may contain multiple concepts)
2642 project_path: Path to the project root
2643 n_results: Maximum results per sub-query
2644 language: Filter by language
2645 node_type: Filter by node type
2646 provider: Embedding provider
2648 Returns:
2649 Formatted results organized by sub-question.
2650 """
2651 import asyncio
2653 print(f"🔍 DECOMPOSED-SEARCH: Analyzing '{query[:50]}...'", file=sys.stderr)
2655 # Decompose query
2656 sub_queries = await _decompose_query_with_llm(query)
2657 print(f" Decomposed into {len(sub_queries)} sub-queries", file=sys.stderr)
2659 if len(sub_queries) == 1 and sub_queries[0] == query:
2660 # No decomposition needed, use regular search
2661 return await semantic_search(
2662 query=query,
2663 project_path=project_path,
2664 n_results=n_results,
2665 language=language,
2666 node_type=node_type,
2667 provider=provider,
2668 )
2670 # Get store once
2671 store = get_store(project_path, provider)
2673 # Search each sub-query in parallel
2674 async def search_sub(q: str) -> tuple[str, list[dict]]:
2675 results = await store.search(
2676 q,
2677 n_results=n_results // len(sub_queries) + 2, # Distribute results
2678 language=language,
2679 node_type=node_type,
2680 )
2681 return (q, results)
2683 sub_results = await asyncio.gather(*[search_sub(q) for q in sub_queries])
2685 # Format output with sections for each sub-query
2686 lines = [f"Decomposed search for: '{query}'"]
2687 lines.append(f"[Split into {len(sub_queries)} sub-queries]\n")
2689 total_results = 0
2690 for sub_query, results in sub_results:
2691 lines.append(f"### {sub_query}")
2693 if not results or "error" in results[0]:
2694 lines.append(" No results found\n")
2695 continue
2697 for i, r in enumerate(results[:5], 1): # Limit per sub-query
2698 lines.append(f" {i}. {r['file']}:{r['lines']} (relevance: {r['relevance']})")
2699 # Shorter preview for decomposed results
2700 full_preview = r.get("code_preview", "")
2701 preview = full_preview[:200] + ("..." if len(full_preview) > 200 else "")
2703 lines.append(f" ```{r.get('language', '')}")
2704 lines.append(f" {preview}")
2705 lines.append(" ```")
2706 total_results += 1
2707 lines.append("")
2709 lines.append(f"[Total: {total_results} results across {len(sub_queries)} sub-queries]")
2711 return "\n".join(lines)
2714async def enhanced_search(
2715 query: str,
2716 project_path: str = ".",
2717 n_results: int = 10,
2718 mode: str = "auto",
2719 language: str | None = None,
2720 node_type: str | None = None,
2721 provider: EmbeddingProvider = "ollama",
2722) -> str:
2723 """
2724 Unified enhanced search combining expansion and decomposition.
2726 Automatically selects the best strategy based on query complexity:
2727 - Simple queries: Multi-query expansion for better recall
2728 - Complex queries: Decomposition + expansion for comprehensive coverage
2730 Args:
2731 query: Search query (simple or complex)
2732 project_path: Path to the project root
2733 n_results: Maximum number of results
2734 mode: Search mode - "auto", "expand", "decompose", or "both"
2735 language: Filter by language
2736 node_type: Filter by node type
2737 provider: Embedding provider
2739 Returns:
2740 Formatted search results.
2741 """
2742 # Use classifier for intelligent mode selection
2743 classification = classify_query(query)
2744 logger.debug(
2745 f"Query classified as {classification.category.value} "
2746 f"(confidence: {classification.confidence:.2f}, suggested: {classification.suggested_tool})"
2747 )
2749 # Determine mode based on classification
2750 if mode == "auto":
2751 # HYBRID → decompose (complex multi-part queries)
2752 # SEMANTIC → expand (conceptual queries benefit from variations)
2753 # PATTERN/STRUCTURAL → expand (simple queries, quick path)
2754 mode = "decompose" if classification.category == QueryCategory.HYBRID else "expand"
2756 if mode == "decompose":
2757 return await decomposed_search(
2758 query=query,
2759 project_path=project_path,
2760 n_results=n_results,
2761 language=language,
2762 node_type=node_type,
2763 provider=provider,
2764 )
2765 elif mode == "expand":
2766 return await multi_query_search(
2767 query=query,
2768 project_path=project_path,
2769 n_results=n_results,
2770 language=language,
2771 node_type=node_type,
2772 provider=provider,
2773 )
2774 elif mode == "both":
2775 # Decompose first, then expand each sub-query
2776 sub_queries = await _decompose_query_with_llm(query)
2778 all_results: list[list[dict]] = []
2779 store = get_store(project_path, provider)
2781 for sub_q in sub_queries:
2782 # Expand each sub-query
2783 expanded = await _expand_query_with_llm(sub_q, num_variations=2)
2784 for exp_q in expanded:
2785 results = await store.search(
2786 exp_q,
2787 n_results=5,
2788 language=language,
2789 node_type=node_type,
2790 )
2791 if results and "error" not in results[0]:
2792 all_results.append(results)
2794 aggregated = _aggregate_results(all_results, n_results)
2796 if not aggregated:
2797 return "No results found"
2799 lines = [f"Enhanced search (decompose+expand) for: '{query}'"]
2800 lines.append(f"[{len(sub_queries)} sub-queries × expansions]\n")
2802 for i, r in enumerate(aggregated, 1):
2803 lines.append(f"{i}. {r['file']}:{r['lines']} (relevance: {r['relevance']})")
2804 lines.append(f"```{r.get('language', '')}")
2805 lines.append(r.get("code_preview", ""))
2806 lines.append("```\n")
2808 return "\n".join(lines)
2810 else:
2811 return f"Unknown mode: {mode}. Use 'auto', 'expand', 'decompose', or 'both'"
2814# ========================
2815# FILE WATCHER IMPLEMENTATION
2816# ========================
2819class DedicatedIndexingWorker:
2820 """Single-threaded worker for all indexing operations.
2822 Prevents concurrent indexing by serializing all operations through a queue.
2823 Uses asyncio.run() for each operation to avoid event loop reuse issues.
2824 """
2826 def __init__(self, store: "CodebaseVectorStore"):
2827 """Initialize the indexing worker.
2829 Args:
2830 store: CodebaseVectorStore instance for reindexing
2831 """
2832 import queue
2834 self.store = store
2835 self._queue: queue.Queue = queue.Queue(maxsize=1) # Max 1 pending request (debouncing)
2836 self._thread: threading.Thread | None = None
2837 self._shutdown = threading.Event()
2838 self._log_file = Path.home() / ".stravinsky" / "logs" / "file_watcher.log"
2839 self._log_file.parent.mkdir(parents=True, exist_ok=True)
2841 def start(self) -> None:
2842 """Start the worker thread."""
2843 if self._thread is not None and self._thread.is_alive():
2844 logger.warning("Indexing worker already running")
2845 return
2847 self._shutdown.clear()
2848 self._thread = threading.Thread(target=self._run_worker, daemon=False, name="IndexingWorker")
2849 self._thread.start()
2850 logger.info(f"Started indexing worker for {self.store.project_path}")
2852 def _log_error(self, msg: str, exc: Exception | None = None):
2853 """Write error to log file with timestamp and full traceback."""
2854 import traceback
2855 from datetime import datetime
2857 timestamp = datetime.now().isoformat()
2858 try:
2859 with open(self._log_file, "a") as f:
2860 f.write(f"\n{'='*80}\n")
2861 f.write(f"[{timestamp}] {msg}\n")
2862 if exc:
2863 f.write(f"Exception: {type(exc).__name__}: {exc}\n")
2864 f.write(traceback.format_exc())
2865 f.write(f"{'='*80}\n")
2866 except Exception as log_exc:
2867 logger.error(f"Failed to write to log file: {log_exc}")
2868 logger.error(f"{msg} (logged to {self._log_file})")
2870 def _run_worker(self) -> None:
2871 """Worker thread entry point - processes queue with asyncio.run() per operation."""
2872 import asyncio
2873 import queue
2874 self._log_error(f"🟢 File watcher started for {self.store.project_path}")
2876 try:
2877 while not self._shutdown.is_set():
2878 try:
2879 # Wait for reindex request (blocking with timeout)
2880 self._queue.get(timeout=0.5)
2881 self._queue.task_done()
2883 # Use asyncio.run() for each operation (creates fresh loop)
2884 # This avoids "event loop already running" errors
2885 try:
2886 asyncio.run(self._do_reindex())
2887 self._log_error(f"✅ Reindex completed for {self.store.project_path}")
2888 except Exception as e:
2889 self._log_error(f"⚠️ Reindex failed for {self.store.project_path}", e)
2891 except queue.Empty:
2892 continue # No work, check shutdown flag
2893 except Exception as e:
2894 self._log_error(f"⚠️ Queue processing error for {self.store.project_path}", e)
2896 except Exception as e:
2897 self._log_error(f"⚠️ Worker thread crashed for {self.store.project_path}", e)
2898 finally:
2899 self._log_error(f"🔴 File watcher stopped for {self.store.project_path}")
2901 async def _do_reindex(self) -> None:
2902 """Execute reindex with retry logic for ALL error types."""
2903 import sqlite3
2905 # retry/stop/wait helpers come from the module-level tenacity import
2907 @retry(
2908 stop=stop_after_attempt(3),
2909 wait=wait_exponential(multiplier=1, min=2, max=10),
2910 retry=retry_if_exception_type((
2911 httpx.HTTPError,
2912 ConnectionError,
2913 TimeoutError,
2914 sqlite3.OperationalError, # Database locked
2915 OSError, # File system errors
2916 )),
2917 reraise=True,
2918 )
2919 async def _indexed():
2920 await self.store.index_codebase(force=False)
2922 await _indexed()
2924 def request_reindex(self, files: list[Path]) -> None:
2925 """Request reindex from any thread (thread-safe).
2927 Args:
2928 files: List of files that changed (for logging only)
2929 """
2930 import queue
2932 try:
2933 # Non-blocking put - drops if queue full (natural debouncing)
2934 self._queue.put_nowait("reindex")
2935 logger.debug(f"📥 Queued reindex for {len(files)} files: {[f.name for f in files[:5]]}")
2936 except queue.Full:
2937 # Already have pending reindex - this is fine (debouncing)
2938 logger.debug(f"Reindex already queued, skipping {len(files)} files")
2940 def shutdown(self) -> None:
2941 """Graceful shutdown of worker thread."""
2942 if self._shutdown.is_set():
2943 return # Already shutting down
2945 self._shutdown.set()
2946 if self._thread is not None and self._thread.is_alive():
2947 self._thread.join(timeout=10) # Wait up to 10 seconds
2948 if self._thread.is_alive():
2949 self._log_error("⚠️ Worker thread failed to stop within timeout")
2950 self._thread = None
2951 logger.info("Indexing worker shut down")
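# Illustrative only: why maxsize=1 yields natural debouncing. Requests arriving while one
# is already pending are dropped; the pending run still covers them because
# index_codebase(force=False) rescans the whole project.
def _demo_queue_debounce() -> int:
    import queue

    q: queue.Queue = queue.Queue(maxsize=1)
    accepted = 0
    for _ in range(5):  # five rapid-fire reindex requests
        try:
            q.put_nowait("reindex")
            accepted += 1
        except queue.Full:
            pass  # collapsed into the pending request
    return accepted  # -> 1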
2954class CodebaseFileWatcher:
2955 """Watch a project directory for file changes and trigger reindexing.
2957 Features:
2958 - Watches for file create, modify, delete, move events
2959 - Filters to .py files only
2960 - Skips hidden files and directories (., .git, __pycache__, venv, etc.)
2961 - Debounces rapid changes to batch them into a single reindex
2962 - Thread-safe: daemon observer thread plus a dedicated non-daemon indexing worker
2963 - Integrates with CodebaseVectorStore for incremental indexing
2964 - Uses dedicated worker thread to prevent concurrent indexing
2965 """
2967 # Default debounce time in seconds
2968 DEFAULT_DEBOUNCE_SECONDS = 2.0
2970 def __init__(
2971 self,
2972 project_path: Path | str,
2973 store: CodebaseVectorStore,
2974 debounce_seconds: float = DEFAULT_DEBOUNCE_SECONDS,
2975 ):
2976 """Initialize the file watcher.
2978 Args:
2979 project_path: Path to the project root to watch
2980 store: CodebaseVectorStore instance for reindexing
2981 debounce_seconds: Time to wait before reindexing after changes (default: 2.0s)
2982 """
2983 self.project_path = Path(project_path).resolve()
2984 self.store = store
2985 self.debounce_seconds = debounce_seconds
2987 # Observer and handler for watchdog
2988 self._observer = None
2989 self._event_handler = None
2991 # Thread safety
2992 self._lock = threading.Lock()
2993 self._running = False
2995 # Debouncing
2996 self._pending_reindex_timer: threading.Timer | None = None
2997 self._pending_files: set[Path] = set()
2998 self._pending_lock = threading.Lock()
3000 # Dedicated indexing worker (prevents concurrent access)
3001 self._indexing_worker = DedicatedIndexingWorker(store)
3003 def start(self) -> None:
3004 """Start watching the project directory.
3006 Creates and starts a watchdog observer in a daemon thread.
3007 Also starts the dedicated indexing worker thread.
3008 """
3009 with self._lock:
3010 if self._running:
3011 logger.warning(f"Watcher for {self.project_path} is already running")
3012 return
3014 try:
3015 # Start indexing worker first (must be running before file events arrive)
3016 self._indexing_worker.start()
3018 watchdog = get_watchdog()
3019 Observer = watchdog["Observer"]
3021 # Create event handler class and instantiate
3022 FileChangeHandler = _create_file_change_handler_class()
3023 self._event_handler = FileChangeHandler(
3024 project_path=self.project_path,
3025 watcher=self,
3026 )
3028 # Create and start observer (daemon mode for clean shutdown)
3029 self._observer = Observer()
3030 self._observer.daemon = True
3031 self._observer.schedule(
3032 self._event_handler,
3033 str(self.project_path),
3034 recursive=True,
3035 )
3036 self._observer.start()
3037 self._running = True
3038 logger.info(f"File watcher started for {self.project_path}")
3040 except Exception as e:
3041 logger.error(f"Failed to start file watcher: {e}")
3042 self._running = False
3043 # Clean up worker if observer failed
3044 self._indexing_worker.shutdown()
3045 raise
3047 def stop(self) -> None:
3048 """Stop watching the project directory.
3050 Cancels any pending reindex timers, stops the observer, and shuts down the indexing worker.
3051 """
3052 with self._lock:
3053 # Cancel pending reindex
3054 if self._pending_reindex_timer is not None:
3055 self._pending_reindex_timer.cancel()
3056 self._pending_reindex_timer = None
3058 # Stop observer
3059 if self._observer is not None:
3060 self._observer.stop()
3061 self._observer.join(timeout=5) # Wait up to 5 seconds for shutdown
3062 self._observer = None
3064 # Shutdown indexing worker
3065 self._indexing_worker.shutdown()
3067 self._event_handler = None
3068 self._running = False
3069 logger.info(f"File watcher stopped for {self.project_path}")
3071 def is_running(self) -> bool:
3072 """Check if the watcher is currently running.
3074 Returns:
3075 True if watcher is active, False otherwise
3076 """
3077 with self._lock:
3078 return self._running and self._observer is not None and self._observer.is_alive()
3080 def _on_file_changed(self, file_path: Path) -> None:
3081 """Called when a file changes (internal use by _FileChangeHandler).
3083 Accumulates files and triggers debounced reindex.
3085 Args:
3086 file_path: Path to the changed file
3087 """
3088 with self._pending_lock:
3089 self._pending_files.add(file_path)
3091 # Cancel previous timer
3092 if self._pending_reindex_timer is not None:
3093 self._pending_reindex_timer.cancel()
3095 # Start new timer
3096 self._pending_reindex_timer = self._create_debounce_timer()
3097 self._pending_reindex_timer.start()
3099 def _create_debounce_timer(self) -> threading.Timer:
3100 """Create a new debounce timer for reindexing.
3102 Returns:
3103 A threading.Timer configured for debounce reindexing
3104 """
3105 return threading.Timer(
3106 self.debounce_seconds,
3107 self._trigger_reindex,
3108 )
3110 def _trigger_reindex(self) -> None:
3111 """Trigger reindexing of accumulated changed files.
3113 This is called after the debounce period expires. Delegates to the
3114 dedicated indexing worker to prevent concurrent access.
3115 """
3116 with self._pending_lock:
3117 if not self._pending_files:
3118 self._pending_reindex_timer = None
3119 return
3121 files_to_index = list(self._pending_files)
3122 self._pending_files.clear()
3123 self._pending_reindex_timer = None
3125 # Delegate to dedicated worker (prevents concurrent indexing)
3126 self._indexing_worker.request_reindex(files_to_index)
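# A minimal standalone sketch of the cancel-and-restart debounce used above (stdlib only).
# Every poke() resets the timer, so the callback fires once, `seconds` after the last
# event in a burst.
class _DemoDebouncer:
    def __init__(self, seconds: float, callback) -> None:
        self._seconds = seconds
        self._callback = callback
        self._timer: threading.Timer | None = None
        self._lock = threading.Lock()

    def poke(self) -> None:
        with self._lock:
            if self._timer is not None:
                self._timer.cancel()
            self._timer = threading.Timer(self._seconds, self._callback)
            self._timer.daemon = True
            self._timer.start()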
3129def _create_file_change_handler_class():
3130 """Create FileChangeHandler class that inherits from FileSystemEventHandler.
3132 This is a factory function that creates the handler class dynamically
3133 after watchdog is imported, allowing for lazy loading.
3134 """
3135 watchdog = get_watchdog()
3136 FileSystemEventHandler = watchdog["FileSystemEventHandler"]
3138 class _FileChangeHandler(FileSystemEventHandler):
3139 """Watchdog event handler for file system changes.
3141 Detects file create, modify, delete, and move events, filters them,
3142 and notifies the watcher of relevant changes.
3143 """
3145 def __init__(self, project_path: Path, watcher: CodebaseFileWatcher):
3146 """Initialize the event handler.
3148 Args:
3149 project_path: Root path of the project being watched
3150 watcher: CodebaseFileWatcher instance to notify
3151 """
3152 super().__init__()
3153 self.project_path = project_path
3154 self.watcher = watcher
3156 def on_created(self, event) -> None:
3157 """Called when a file is created."""
3158 if not event.is_directory and self._should_index_file(event.src_path):
3159 logger.debug(f"File created: {event.src_path}")
3160 self.watcher._on_file_changed(Path(event.src_path))
3162 def on_modified(self, event) -> None:
3163 """Called when a file is modified."""
3164 if not event.is_directory and self._should_index_file(event.src_path):
3165 logger.debug(f"File modified: {event.src_path}")
3166 self.watcher._on_file_changed(Path(event.src_path))
3168 def on_deleted(self, event) -> None:
3169 """Called when a file is deleted."""
3170 if not event.is_directory and self._should_index_file(event.src_path):
3171 logger.debug(f"File deleted: {event.src_path}")
3172 self.watcher._on_file_changed(Path(event.src_path))
3174 def on_moved(self, event) -> None:
3175 """Called when a file is moved."""
3176 if not event.is_directory:
3177 # Check destination path
3178 if self._should_index_file(event.dest_path):
3179 logger.debug(f"File moved: {event.src_path} -> {event.dest_path}")
3180 self.watcher._on_file_changed(Path(event.dest_path))
3181 # Also check source path (for deletion case)
3182 elif self._should_index_file(event.src_path):
3183 logger.debug(f"File moved out: {event.src_path}")
3184 self.watcher._on_file_changed(Path(event.src_path))
3186 def _should_index_file(self, file_path: str) -> bool:
3187 """Check if a file should trigger reindexing.
3189 Filters based on:
3190 - File extension (.py only)
3191 - Hidden files and directories (starting with .)
3192 - Skip directories (venv, __pycache__, .git, node_modules, etc.)
3194 Args:
3195 file_path: Path to the file to check
3197 Returns:
3198 True if file should trigger reindexing, False otherwise
3199 """
3200 path = Path(file_path)
3202 # Only .py files
3203 if path.suffix != ".py":
3204 return False
3206 # Skip hidden files
3207 if path.name.startswith("."):
3208 return False
3210 # Check for skip directories in the path
3211 for part in path.parts:
3212 if part.startswith("."): # Hidden directories like .git, .venv
3213 return False
3214 if part in {"__pycache__", "venv", "env", "node_modules"}:
3215 return False
3217 # File is within project (resolve both paths to handle symlinks)
3218 try:
3219 path.resolve().relative_to(self.project_path)
3220 return True
3221 except ValueError:
3222 # File is outside project
3223 return False
3225 return _FileChangeHandler
3228# ========================
3229# CHROMADB LOCK CLEANUP
3230# ========================
3233def _is_process_alive(pid: int) -> bool:
3234 """Check if a process with given PID is currently running.
3236 Cross-platform process existence check.
3238 Args:
3239 pid: Process ID to check
3241 Returns:
3242 True if process exists, False otherwise
3243 """
3244 import os
3245 import sys
3247 if sys.platform == "win32":
3248 # Windows: Use tasklist command
3249 import subprocess
3251 try:
3252 result = subprocess.run(
3253 ["tasklist", "/FI", f"PID eq {pid}"], capture_output=True, text=True, timeout=2
3254 )
3255 return str(pid) in result.stdout
3256 except Exception:
3257 return False
3258 else:
3259 # Unix/Linux/macOS: Use os.kill(pid, 0)
3260 try:
3261 os.kill(pid, 0)
3262 return True
3263 except OSError:
3264 return False
3265 except Exception:
3266 return False
3269def cleanup_stale_chromadb_locks() -> int:
3270 """Remove stale ChromaDB lock files on MCP server startup.
3272 Scans all vectordb directories and removes lock files that:
3273 1. Are older than 60 seconds (short grace period for active operations)
3274 2. Don't have an owning process running (if PID can be determined)
3276 This prevents 'Connection closed' errors from dead process locks.
3278 Returns:
3279 Number of stale locks removed
3280 """
3281 vectordb_base = Path.home() / ".stravinsky" / "vectordb"
3282 if not vectordb_base.exists():
3283 return 0 # No vectordb yet, nothing to cleanup
3285 import time
3287 removed_count = 0
3289 for project_dir in vectordb_base.iterdir():
3290 if not project_dir.is_dir():
3291 continue
3293 lock_path = project_dir / ".chromadb.lock"
3294 if not lock_path.exists():
3295 continue
3297 # Check lock age
3298 try:
3299 lock_age = time.time() - lock_path.stat().st_mtime
3300 except Exception:
3301 continue
3303 # Aggressive cleanup: remove locks older than 60 seconds
3304 # This catches recently crashed processes (old 300s was too conservative)
3305 is_stale = lock_age > 60
3307 # TODO: If lock file contains PID, check if process is alive
3308 # filelock doesn't write PID by default, but we could enhance this
3310 if is_stale:
3311 try:
3312 lock_path.unlink(missing_ok=True)
3313 removed_count += 1
3314 logger.info(f"Removed stale lock: {lock_path} (age: {lock_age:.0f}s)")
3315 except Exception as e:
3316 logger.warning(f"Could not remove stale lock {lock_path}: {e}")
3318 if removed_count > 0:
3319 logger.info(f"Startup cleanup: removed {removed_count} stale ChromaDB lock(s)")
3321 return removed_count
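# Illustrative only: the staleness test above, isolated (hypothetical lock path).
def _demo_is_stale(lock_path: Path, max_age_seconds: float = 60.0) -> bool:
    import time

    try:
        return (time.time() - lock_path.stat().st_mtime) > max_age_seconds
    except OSError:
        return False  # lock vanished or is unreadable; leave it alone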