Coverage for session_buddy / worktree_manager.py: 75.63%
279 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
1#!/usr/bin/env python3
2"""Git Worktree Management for Session Management MCP Server.
4Provides high-level worktree operations and coordination with session management.
5"""
7import json
8import os
9import shutil
10import subprocess # nosec B404
11import time
12from dataclasses import dataclass, field
13from datetime import datetime
14from pathlib import Path
15from typing import Any
17from .utils.git_operations import (
18 WorktreeInfo,
19 get_worktree_info,
20 is_git_repository,
21 list_worktrees,
22)
@dataclass(frozen=True)
class WorktreeCreationOptions:
    """Flags describing how a new worktree should be created.

    The dataclass is frozen, so an options object cannot be mutated after
    construction and is hashable.
    """

    # Create a new branch for the worktree.
    create_branch: bool = field(default=False)
    # Check out an already existing branch instead of creating one.
    checkout_existing: bool = field(default=False)
    # Whether the creation should be forced.
    force: bool = field(default=False)
@dataclass
class WorktreeValidationResult:
    """Outcome of validating a worktree request: a verdict plus messages."""

    is_valid: bool
    # Reasons the request was rejected; empty for results built by success().
    errors: list[str] = field(default_factory=list)

    @classmethod
    def success(cls) -> "WorktreeValidationResult":
        """Build a passing result that carries no error messages."""
        return cls(True)

    @classmethod
    def error(cls, error: str) -> "WorktreeValidationResult":
        """Build a failing result whose only message is ``error``."""
        messages = [error]
        return cls(False, messages)
@dataclass
class GitOperationResult:
    """Outcome of a git operation: a success flag plus output and error text."""

    success: bool
    output: str = ""
    error: str = ""

    @classmethod
    def success_result(cls, output: str = "") -> "GitOperationResult":
        """Build a successful result carrying ``output`` (empty by default)."""
        return cls(True, output)

    @classmethod
    def error_result(cls, error: str) -> "GitOperationResult":
        """Build a failed result carrying the message ``error``."""
        return cls(False, error=error)
class WorktreeManager:
    """Manages git worktrees with session coordination.

    Every git invocation uses the absolute git path returned by
    ``shutil.which``, is screened by :meth:`_validate_git_command`, and runs
    with ``shell=False`` and a 30-second timeout. The public coroutine
    methods return result dictionaries carrying a ``success`` flag.
    """

    # Git subcommands that _validate_git_command() allows in cmd[1].
    _ALLOWED_GIT_SUBCOMMANDS = frozenset(
        {"worktree", "status", "add", "commit", "branch", "checkout"},
    )
    # Accepted basenames for the git executable ("git.exe" on Windows).
    _GIT_EXECUTABLE_NAMES = frozenset({"git", "git.exe"})
    # Characters rejected in git arguments. This is defense in depth: the
    # commands run with shell=False, so no shell ever interprets them.
    _FORBIDDEN_ARG_CHARS = (";", "&", "|", "`", "$", "\\", "\n", "\r")
    # Environment-variable name fragments whose values are masked before the
    # session state is persisted to disk.
    _SENSITIVE_ENV_MARKERS = (
        "KEY",
        "TOKEN",
        "SECRET",
        "PASSWORD",
        "PASSWD",
        "CREDENTIAL",
        "AUTH",
        "COOKIE",
    )
    # Placeholder written in place of masked environment values.
    _REDACTED = "***REDACTED***"

    def __init__(self, session_logger: Any = None) -> None:
        """Create a manager.

        Args:
            session_logger: Optional logger exposing ``info``/``warning``/
                ``error`` methods that accept ``(message, **context)``.
                Logging is skipped when it is falsy.
        """
        self.session_logger = session_logger

    def _log(self, message: str, level: str = "info", **context: Any) -> None:
        """Forward ``message`` to ``session_logger.<level>`` if a logger is set."""
        if self.session_logger:
            getattr(self.session_logger, level)(message, **context)

    def _get_git_executable(self) -> str:
        """Security: Get the full path to git executable to prevent PATH injection.

        Raises:
            OSError: If no ``git`` executable is found on PATH.
        """
        git_path = shutil.which("git")
        if not git_path:
            msg = "Git executable not found in PATH"
            raise OSError(msg)
        return git_path

    def _validate_git_command(self, cmd: list[str]) -> bool:
        """Security: Validate git command arguments to prevent injection.

        Returns True only when ``cmd`` has at least two elements, ``cmd[0]``
        names a git executable, ``cmd[1]`` is an allowed subcommand, and no
        later argument contains a forbidden character. The executable path
        is exempt from the character check because it comes from
        ``shutil.which`` (and contains backslashes on Windows).
        """
        if not cmd or len(cmd) < 2:
            return False

        # Compare the basename exactly: a plain endswith("git") also accepted
        # names such as "evilgit" while rejecting "git.exe".
        if Path(cmd[0]).name.lower() not in self._GIT_EXECUTABLE_NAMES:
            return False

        if cmd[1] not in self._ALLOWED_GIT_SUBCOMMANDS:
            return False

        return not any(
            char in arg for arg in cmd[1:] for char in self._FORBIDDEN_ARG_CHARS
        )

    def _is_safe_branch_name(self, branch: str) -> bool:
        """Security: Validate branch name is safe to pass to git.

        Accepts 1-99 characters drawn from ASCII letters, digits, ``_``,
        ``-`` and ``/`` (for names such as ``feature/x``), provided the name
        does not start with ``-``.
        """
        import re

        # A leading "-" would be parsed by git as an option (e.g. "--detach").
        if branch.startswith("-") or len(branch) >= 100:
            return False
        # fullmatch instead of match + "$": "$" also matches just before a
        # trailing newline, which let names like "main\n" through.
        return (
            re.fullmatch(r"[a-zA-Z0-9_/-]+", branch)  # REGEX OK: validated safe pattern
            is not None
        )

    def _is_safe_path(self, path: Path) -> bool:
        """Security: Validate path is safe and reasonable.

        Returns False when ``path`` cannot be resolved, when the resolved
        path contains a NUL byte or a ``..`` component, or when it is longer
        than 500 characters.
        """
        try:
            abs_path = path.resolve()
            path_str = str(abs_path)

            # Check components rather than substrings so legitimate names
            # such as "v1..2" are not rejected.
            if "\x00" in path_str or ".." in abs_path.parts:
                return False

            return len(path_str) <= 500
        except (OSError, ValueError):
            return False

    async def list_worktrees(self, directory: Path) -> dict[str, Any]:
        """List all worktrees of the repository containing ``directory``.

        Returns:
            ``{"success": True, "worktrees": [...], "current_worktree": str |
            None, "total_count": int}``. On failure it returns
            ``{"success": False, "error": str, "worktrees": []}``.
        """
        if not is_git_repository(directory):
            return {"success": False, "error": "Not a git repository", "worktrees": []}

        try:
            worktrees = list_worktrees(directory)
            current_worktree = get_worktree_info(directory)
            current_path = current_worktree.path if current_worktree else None

            worktree_data = [
                {
                    "path": str(wt.path),
                    "branch": wt.branch,
                    "is_main": wt.is_main_worktree,
                    # Always a bool; previously None leaked through when the
                    # current worktree could not be determined.
                    "is_current": current_path is not None and wt.path == current_path,
                    "is_detached": wt.is_detached,
                    "is_bare": wt.is_bare,
                    "locked": wt.locked,
                    "prunable": wt.prunable,
                    "exists": wt.path.exists(),
                    "has_session": self._check_session_exists(wt.path),
                }
                for wt in worktrees
            ]

            self._log("Listed worktrees", worktrees_count=len(worktree_data))

            return {
                "success": True,
                "worktrees": worktree_data,
                "current_worktree": str(current_path)
                if current_path is not None
                else None,
                "total_count": len(worktree_data),
            }

        except Exception as e:
            self._log(f"Failed to list worktrees: {e}", level="error")
            return {"success": False, "error": str(e), "worktrees": []}

    def _validate_worktree_creation_request(
        self,
        repository_path: Path,
        new_path: Path,
        branch: str,
    ) -> WorktreeValidationResult:
        """Check a creation request before any git command is built.

        Rejects the request if the source is not a git repository, the
        target path already exists, the branch name is unsafe, or the target
        path fails :meth:`_is_safe_path`. Only the first failure is reported.
        """
        if not is_git_repository(repository_path):
            return WorktreeValidationResult.error(
                "Source directory is not a git repository",
            )

        if new_path.exists():
            return WorktreeValidationResult.error(
                f"Target path already exists: {new_path}",
            )

        # Security: Validate branch name to prevent injection
        if not branch or not self._is_safe_branch_name(branch):
            return WorktreeValidationResult.error(
                "Invalid branch name: must be alphanumeric with dashes/underscores only",
            )

        # Security: Validate path is within reasonable bounds
        if not self._is_safe_path(new_path):
            return WorktreeValidationResult.error(
                "Invalid path: path must be relative to current directory structure",
            )

        return WorktreeValidationResult.success()

    def _build_worktree_command(
        self,
        new_path: Path,
        branch: str,
        options: WorktreeCreationOptions,
    ) -> list[str]:
        """Build the ``git worktree add`` argument list for ``options``.

        - ``create_branch``: ``add -b <branch> <path>`` (the new branch
          starts at the current HEAD).
        - ``checkout_existing``: ``add --track -B <branch> <path> <branch>``.
        - otherwise: ``add <path> <branch>``.

        ``options.force`` prepends ``--force``.
        """
        cmd = [self._get_git_executable(), "worktree", "add"]

        if options.force:
            cmd.append("--force")

        if options.create_branch:
            # Do not repeat <branch> as the commit-ish: it does not exist yet,
            # so git would fail with "invalid reference".
            cmd.extend(["-b", branch, str(new_path)])
        elif options.checkout_existing:
            cmd.extend(["--track", "-B", branch, str(new_path), branch])
        else:
            cmd.extend([str(new_path), branch])
        return cmd

    def _run_git(
        self,
        cmd: list[str],
        cwd: Path,
    ) -> subprocess.CompletedProcess[str]:
        """Run an already-validated git command in ``cwd``.

        Raises:
            subprocess.CalledProcessError: On a non-zero exit status.
            subprocess.TimeoutExpired: If git runs longer than 30 seconds.
        """
        return subprocess.run(  # nosec B603 - Command validated via _validate_git_command()
            cmd,
            cwd=cwd,
            capture_output=True,
            text=True,
            check=True,
            timeout=30,  # Security: Prevent hanging processes
            shell=False,  # Security: Explicit shell=False to prevent injection
        )

    def _execute_worktree_creation(
        self,
        cmd: list[str],
        repository_path: Path,
    ) -> subprocess.CompletedProcess[str]:
        """Execute git worktree add with security hardening."""
        return self._run_git(cmd, repository_path)

    def _build_success_response_from_info(
        self,
        new_path: Path,
        branch: str,
        worktree_info: Any,
        output: str,
    ) -> dict[str, Any]:
        """Assemble the ``create_worktree`` success payload."""
        return {
            "success": True,
            "worktree_path": str(new_path),
            "branch": branch,
            "worktree_info": {
                "path": str(worktree_info.path),
                "branch": worktree_info.branch,
                "is_main": worktree_info.is_main_worktree,
                "is_detached": worktree_info.is_detached,
            },
            "output": output,
        }

    async def _execute_git_worktree_creation(
        self,
        new_path: Path,
        branch: str,
        options: WorktreeCreationOptions,
        repository_path: Path,
    ) -> GitOperationResult:
        """Build, validate and run ``git worktree add``; never raises."""
        try:
            cmd = self._build_worktree_command(new_path, branch, options)

            # Security: Validate command before execution
            if not self._validate_git_command(cmd):
                return GitOperationResult.error_result(
                    "Invalid git command detected - potential security risk",
                )

            result = self._execute_worktree_creation(cmd, repository_path)
            return GitOperationResult.success_result(result.stdout.strip())

        except subprocess.CalledProcessError as e:
            error_msg = e.stderr.strip() if e.stderr else str(e)
            self._log(f"Failed to create worktree: {error_msg}", level="error")
            return GitOperationResult.error_result(error_msg)
        except Exception as e:
            self._log(f"Unexpected error creating worktree: {e}", level="error")
            return GitOperationResult.error_result(str(e))

    def _verify_worktree_creation(self, new_path: Path) -> GitOperationResult:
        """Confirm that worktree information can be read for ``new_path``."""
        worktree_info = get_worktree_info(new_path)
        if not worktree_info:
            return GitOperationResult.error_result(
                "Worktree was created but cannot be accessed",
            )
        return GitOperationResult.success_result()

    async def create_worktree(
        self,
        repository_path: Path,
        new_path: Path,
        branch: str,
        create_branch: bool = False,
        checkout_existing: bool = False,
        force: bool = False,
    ) -> dict[str, Any]:
        """Create a new worktree at ``new_path`` for ``branch``.

        Args:
            repository_path: Repository in which ``git worktree add`` runs.
            new_path: Location of the new worktree; must not exist yet.
            branch: Branch to check out (or create when ``create_branch``).
            create_branch: Create ``branch`` from the current HEAD.
            checkout_existing: Use ``--track -B`` with an existing branch.
            force: Pass ``--force`` to ``git worktree add``.

        Returns:
            The success payload, or ``{"success": False, "error": str}``.
        """
        options = WorktreeCreationOptions(
            create_branch=create_branch,
            checkout_existing=checkout_existing,
            force=force,
        )

        validation = self._validate_worktree_creation_request(
            repository_path,
            new_path,
            branch,
        )
        if not validation.is_valid:
            return {"success": False, "error": validation.errors[0]}

        git_result = await self._execute_git_worktree_creation(
            new_path,
            branch,
            options,
            repository_path,
        )
        if not git_result.success:
            return {"success": False, "error": git_result.error}

        verify_result = self._verify_worktree_creation(new_path)
        if not verify_result.success:
            return {"success": False, "error": verify_result.error}

        worktree_info = get_worktree_info(new_path)
        self._log("Created worktree", path=str(new_path), branch=branch)
        return self._build_success_response_from_info(
            new_path,
            branch,
            worktree_info,
            git_result.output,
        )

    async def remove_worktree(
        self,
        repository_path: Path,
        worktree_path: Path,
        force: bool = False,
    ) -> dict[str, Any]:
        """Run ``git worktree remove [--force] <worktree_path>``.

        Returns ``{"success": True, "removed_path": ..., "output": ...}`` or
        ``{"success": False, "error": ...}``.
        """
        if not is_git_repository(repository_path):
            return {
                "success": False,
                "error": "Source directory is not a git repository",
            }

        try:
            cmd = [self._get_git_executable(), "worktree", "remove"]
            if force:
                cmd.append("--force")
            cmd.append(str(worktree_path))

            # Security: Validate command before execution
            if not self._validate_git_command(cmd):
                return {
                    "success": False,
                    "error": "Invalid git command detected - potential security risk",
                }

            result = self._run_git(cmd, repository_path)

            self._log("Removed worktree", path=str(worktree_path))

            return {
                "success": True,
                "removed_path": str(worktree_path),
                "output": result.stdout.strip() or "Worktree removed successfully",
            }

        except subprocess.CalledProcessError as e:
            error_msg = e.stderr.strip() if e.stderr else str(e)
            self._log(f"Failed to remove worktree: {error_msg}", level="error")
            return {"success": False, "error": error_msg}
        except Exception as e:
            self._log(f"Unexpected error removing worktree: {e}", level="error")
            return {"success": False, "error": str(e)}

    async def prune_worktrees(self, repository_path: Path) -> dict[str, Any]:
        """Run ``git worktree prune --verbose`` and count pruned entries.

        Returns ``{"success": True, "pruned_count": int, "output": str}`` or
        ``{"success": False, "error": str}``.
        """
        if not is_git_repository(repository_path):
            return {"success": False, "error": "Directory is not a git repository"}

        try:
            cmd = [self._get_git_executable(), "worktree", "prune", "--verbose"]

            # Security: Validate command before execution
            if not self._validate_git_command(cmd):
                return {
                    "success": False,
                    "error": "Invalid git command detected - potential security risk",
                }

            result = self._run_git(cmd, repository_path)

            stdout = result.stdout.strip()
            # Guard against a missing/None stderr.
            stderr = result.stderr.strip() if isinstance(result.stderr, str) else ""
            # Depending on the git version, "Removing worktrees/<id>: ..." lines
            # may go to stderr rather than stdout, so count both streams.
            pruned_count = sum(
                1
                for stream in (stdout, stderr)
                for line in stream.splitlines()
                if "Removing" in line
            )

            self._log("Pruned worktrees", pruned_count=pruned_count)

            return {
                "success": True,
                "pruned_count": pruned_count,
                "output": stdout or stderr or "No worktrees to prune",
            }

        except subprocess.CalledProcessError as e:
            error_msg = e.stderr.strip() if e.stderr else str(e)
            self._log(f"Failed to prune worktrees: {error_msg}", level="error")
            return {"success": False, "error": error_msg}
        except Exception as e:
            # Missing git binary, timeouts, etc. now yield an error dict like
            # the sibling methods instead of propagating.
            self._log(f"Unexpected error pruning worktrees: {e}", level="error")
            return {"success": False, "error": str(e)}

    async def get_worktree_status(self, directory: Path) -> dict[str, Any]:
        """Get comprehensive status for current worktree and all related worktrees."""
        if not is_git_repository(directory):
            return {"success": False, "error": "Not a git repository"}

        try:
            current_worktree = get_worktree_info(directory)
            all_worktrees = list_worktrees(directory)

            if not current_worktree:
                return {
                    "success": False,
                    "error": "Could not determine current worktree info",
                }

            return {
                "success": True,
                "current_worktree": {
                    "path": str(current_worktree.path),
                    "branch": current_worktree.branch,
                    "is_main": current_worktree.is_main_worktree,
                    "is_detached": current_worktree.is_detached,
                    "has_session": self._check_session_exists(current_worktree.path),
                },
                "all_worktrees": [
                    {
                        "path": str(wt.path),
                        "branch": wt.branch,
                        "is_main": wt.is_main_worktree,
                        "is_current": wt.path == current_worktree.path,
                        "exists": wt.path.exists(),
                        "has_session": self._check_session_exists(wt.path),
                        "prunable": wt.prunable,
                    }
                    for wt in all_worktrees
                ],
                "total_worktrees": len(all_worktrees),
                "session_summary": self._get_session_summary(all_worktrees),
            }

        except Exception as e:
            self._log(f"Failed to get worktree status: {e}", level="error")
            return {"success": False, "error": str(e)}

    def _check_session_exists(self, path: Path) -> bool:
        """Heuristically decide whether ``path`` hosts a session.

        True when ``path`` exists and contains ``.git``, ``.claude`` or
        ``.session``, or one of ``pyproject.toml``, ``package.json``,
        ``requirements.txt`` or ``setup.py``.

        NOTE(review): every git worktree has a ``.git`` entry, so for an
        existing worktree this is effectively always True - confirm intent.
        """
        if isinstance(path, str):
            path = Path(path)

        if not path.exists():
            return False

        session_indicators = [
            path / ".git",  # Git repository
            path / ".claude",  # Claude session directory
            path / ".session",  # Generic session directory
        ]
        project_files = [
            "pyproject.toml",
            "package.json",
            "requirements.txt",
            "setup.py",
        ]

        has_session_indicators = any(
            indicator.exists() for indicator in session_indicators
        )
        has_project_files = any(
            (path / proj_file).exists() for proj_file in project_files
        )

        return has_session_indicators or has_project_files

    def _get_session_summary(self, worktrees: list[WorktreeInfo]) -> dict[str, Any]:
        """Count worktrees with sessions and collect their distinct branches."""
        active_sessions = 0
        branches = set()

        for wt in worktrees:
            if self._check_session_exists(wt.path):
                active_sessions += 1
                branches.add(wt.branch)

        return {
            "active_sessions": active_sessions,
            "unique_branches": len(branches),
            # Sorted for deterministic output; key=str tolerates a None branch.
            "branches": sorted(branches, key=str),
        }

    def _redact_environment(self) -> dict[str, str]:
        """Return a copy of ``os.environ`` with secret-looking values masked.

        A value is masked when its upper-cased variable name contains any
        entry of ``_SENSITIVE_ENV_MARKERS``. This is a name-based heuristic,
        not a guarantee that no secret is ever persisted.
        """
        return {
            name: (
                self._REDACTED
                if any(marker in name.upper() for marker in self._SENSITIVE_ENV_MARKERS)
                else value
            )
            for name, value in os.environ.items()
        }

    def _save_current_session_state(self, worktree_path: Path) -> dict[str, Any] | None:
        """Snapshot session context for ``worktree_path`` and persist it as JSON.

        The file is written to
        ``~/.claude/worktree_sessions/session_state_<name>.json`` with
        owner-only permissions. Returns the saved dict, or ``None`` (after
        logging a warning) if anything fails.
        """
        try:
            state = {
                "timestamp": datetime.now().isoformat(),
                "worktree_path": str(worktree_path),
                "working_directory": str(Path.cwd()),
                # Mask secret-looking values so API keys/tokens are not written
                # verbatim into a plaintext file.
                "environment": self._redact_environment(),
                "recent_files": self._get_recent_files(worktree_path),
                "git_status": self._get_git_status(worktree_path),
            }

            claude_dir = Path.home() / ".claude" / "worktree_sessions"
            claude_dir.mkdir(parents=True, exist_ok=True)

            state_file = claude_dir / f"session_state_{worktree_path.name}.json"
            # Create the file owner-only; chmod as well because O_CREAT's mode
            # is ignored when the file already exists.
            fd = os.open(state_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, "w") as f:
                state_file.chmod(0o600)
                json.dump(state, f, indent=2)

            return state
        except Exception as e:
            self._log(f"Failed to save session state: {e}", level="warning")
            return None

    def _restore_session_state(
        self,
        worktree_path: Path,
        state: dict[str, Any] | None,
    ) -> bool:
        """Restore session state for the target worktree.

        Currently this only logs the restore. It returns False when ``state``
        is empty or logging fails, and True otherwise.
        """
        if not state:
            return False

        try:
            # Only logging for now; restoring environment, open files or IDE
            # state would go here.
            self._log(
                "Session state restored",
                worktree=worktree_path.name,
                recent_files=len(state.get("recent_files", [])),
            )
            return True
        except Exception as e:
            self._log(f"Failed to restore session state: {e}", level="warning")
            return False

    def _get_recent_files(self, worktree_path: Path) -> list[str]:
        """Return up to 20 files modified in the last 24 hours, newest first.

        Paths are relative to ``worktree_path``. Files and directories whose
        name starts with ``.`` (inside the worktree) are skipped, and hidden
        directories such as ``.git`` are never descended into. Any unexpected
        error yields an empty list.
        """
        cutoff_time = time.time() - (24 * 60 * 60)
        candidates: list[tuple[float, str]] = []

        try:
            for root, dirnames, filenames in os.walk(worktree_path):
                # Prune hidden directories in place so os.walk skips them.
                dirnames[:] = [d for d in dirnames if not d.startswith(".")]
                root_path = Path(root)
                for name in filenames:
                    if name.startswith("."):
                        continue
                    file_path = root_path / name
                    try:
                        if not file_path.is_file():
                            continue
                        mtime = file_path.stat().st_mtime
                    except OSError:
                        continue
                    if mtime > cutoff_time:
                        candidates.append(
                            (mtime, str(file_path.relative_to(worktree_path))),
                        )
        except Exception:
            return []

        # Newest first, so the cap keeps the 20 *most recent* files.
        candidates.sort(key=lambda item: item[0], reverse=True)
        return [relative for _, relative in candidates[:20]]

    def _get_git_status(self, worktree_path: Path) -> dict[str, Any]:
        """Summarize modified and untracked files; all-empty on any error."""
        try:
            from .utils.git_operations import get_git_status

            modified, untracked = get_git_status(worktree_path)
            return {
                "modified_files": modified,
                "untracked_files": untracked,
                "has_changes": len(modified) > 0 or len(untracked) > 0,
            }
        except Exception:
            return {"modified_files": [], "untracked_files": [], "has_changes": False}

    async def switch_worktree_context(
        self,
        from_path: Path,
        to_path: Path,
    ) -> dict[str, Any]:
        """Switch the process working directory from one worktree to another.

        Saves session state for ``from_path``, then ``os.chdir(to_path)``
        (a process-wide side effect), then restores state for ``to_path``.
        If that sequence raises, it falls back to a plain ``chdir``.
        """
        try:
            if not is_git_repository(from_path):
                return {
                    "success": False,
                    "error": f"Source path is not a git repository: {from_path}",
                }

            if not is_git_repository(to_path):
                return {
                    "success": False,
                    "error": f"Target path is not a git repository: {to_path}",
                }

            from_worktree = get_worktree_info(from_path)
            to_worktree = get_worktree_info(to_path)

            if not from_worktree or not to_worktree:
                return {
                    "success": False,
                    "error": "Could not get worktree information for context switch",
                }

            from_summary = {
                "path": str(from_worktree.path),
                "branch": from_worktree.branch,
            }
            to_summary = {
                "path": str(to_worktree.path),
                "branch": to_worktree.branch,
            }

            try:
                session_state = self._save_current_session_state(from_path)
                os.chdir(to_path)
                restored_state = self._restore_session_state(to_path, session_state)

                self._log(
                    "Context switch completed with session preservation",
                    from_branch=from_worktree.branch,
                    to_branch=to_worktree.branch,
                )

                return {
                    "success": True,
                    "from_worktree": from_summary,
                    "to_worktree": to_summary,
                    "context_preserved": True,
                    "session_state_saved": session_state is not None,
                    "session_state_restored": restored_state,
                    "message": f"Switched from {from_worktree.branch} to {to_worktree.branch}",
                }
            except Exception as session_error:
                # Fallback to basic switching if session preservation fails
                self._log(
                    f"Session preservation failed, using basic switching: {session_error}",
                    level="warning",
                )
                os.chdir(to_path)

                self._log(
                    "Basic context switch completed",
                    from_branch=from_worktree.branch,
                    to_branch=to_worktree.branch,
                )

                return {
                    "success": True,
                    "from_worktree": from_summary,
                    "to_worktree": to_summary,
                    "context_preserved": False,
                    "session_error": str(session_error),
                    "message": f"Switched from {from_worktree.branch} to {to_worktree.branch} (session preservation failed)",
                }

        except Exception as e:
            self._log(f"Failed to switch worktree context: {e}", level="error")
            return {"success": False, "error": str(e)}