# Source code for askemblaex.preflight

"""
askemblaex/preflight.py

Preflight checks for external service credentials.
Tests Azure CV, Azure DocInt, and OpenAI before processing begins.
If a key is missing or a connection fails, the user is prompted to
quit, continue without that service, or enter the key interactively.
"""

from __future__ import annotations

import os
import sys
from dataclasses import dataclass, field
from typing import Optional

# ─────────────────────────────────────────────
# ANSI colours (duplicated here to avoid circular import with main)
# ─────────────────────────────────────────────

RED    = "\x1b[31m"   # failures ("[!] ..." banners, FAILED markers)
GREEN  = "\x1b[32m"   # success (OK markers, enabled features)
YELLOW = "\x1b[33m"   # warnings / user-skipped services / retry notices
CYAN   = "\x1b[36m"   # highlighted identifiers (env-var names in prompts)
DIM    = "\x1b[2m"    # de-emphasised detail text
BOLD   = "\x1b[1m"    # section headings
RESET  = "\x1b[0m"    # restore default terminal attributes

# Fallbacks applied when OPENAI_BASE_URL / OPENAI_MODEL are not set in the
# environment.
OPENAI_DEFAULT_BASE_URL = "https://api.openai.com/v1"
OPENAI_DEFAULT_MODEL    = "gpt-4o"


# ─────────────────────────────────────────────
# Result types
# ─────────────────────────────────────────────

@dataclass
class ServiceStatus:
    """Outcome of a preflight check for one external service.

    Attributes:
        name: Canonical service identifier (e.g. ``"azure_docint"``).
        available: True when credentials were found and the connectivity
            test succeeded.
        reason: Human-readable explanation when ``available`` is False.
        env_vars: Environment variables involved in the failure; used by
            the interactive "Enter credentials now" flow to know what to
            prompt for.
    """

    name: str
    available: bool
    reason: str = ""
    env_vars: list[str] = field(default_factory=list)
@dataclass
class PreflightResult:
    """Aggregated outcome of all preflight checks.

    Attributes:
        services: Per-service ``ServiceStatus``, keyed by service name.
        active_methods: Extraction methods that remain usable after checks
            (and after any user "continue without" choices).
        openai_available: Whether an OpenAI connection was established.
        openai_model: Model name to use for OpenAI calls.
        openai_client: An ``openai.OpenAI`` instance when available, else None.
        embed_provider: ``"ollama"``, ``"openai"``, or None when embeddings
            are unavailable/disabled.
    """

    services: dict[str, ServiceStatus] = field(default_factory=dict)
    active_methods: set[str] = field(default_factory=set)
    openai_available: bool = False
    openai_model: str = OPENAI_DEFAULT_MODEL
    openai_client: object = None  # openai.OpenAI if available
    embed_provider: Optional[str] = None

    def method_available(self, method: str) -> bool:
        """Return True if *method* survived preflight and may be used."""
        return method in self.active_methods
# ─────────────────────────────────────────────
# Interactive prompt helpers
# ─────────────────────────────────────────────

def _prompt_choice(prompt: str, options: list[str]) -> int:
    """Print numbered options and return the 1-based index chosen.

    Loops until a valid choice is entered.
    """
    print(prompt)
    for i, opt in enumerate(options, 1):
        print(f" [{i}] {opt}")
    while True:
        raw = input(" > ").strip()
        if raw.isdigit() and 1 <= int(raw) <= len(options):
            return int(raw)
        print(f" {YELLOW}Please enter a number between 1 and {len(options)}.{RESET}")


def _prompt_key(env_var: str) -> Optional[str]:
    """Prompt the user to enter a key value. Returns None if empty."""
    raw = input(f" Enter value for {CYAN}{env_var}{RESET}: ").strip()
    return raw if raw else None


def _set_env(key: str, value: str) -> None:
    """Store *value* in the process environment so later checks/SDKs see it."""
    os.environ[key] = value


# ─────────────────────────────────────────────
# Individual service checks
# ─────────────────────────────────────────────

def _check_azure_cv() -> ServiceStatus:
    """Check Azure Computer Vision credentials and connectivity.

    Returns an available status only when both env vars are set and a
    lightweight API call (``list_models``) succeeds.
    """
    endpoint = os.getenv("AZURE_VISION_ENDPOINT", "").strip()
    key = os.getenv("AZURE_VISION_KEY", "").strip()
    if not endpoint or not key:
        missing = []
        if not endpoint:
            missing.append("AZURE_VISION_ENDPOINT")
        if not key:
            missing.append("AZURE_VISION_KEY")
        return ServiceStatus(
            name="azure_computer_vision",
            available=False,
            reason=f"Missing: {', '.join(missing)}",
            env_vars=missing,
        )
    try:
        from azure.cognitiveservices.vision.computervision import ComputerVisionClient
        from msrest.authentication import CognitiveServicesCredentials

        client = ComputerVisionClient(endpoint, CognitiveServicesCredentials(key))
        # Lightweight connectivity test
        client.list_models()
        return ServiceStatus(name="azure_computer_vision", available=True)
    except Exception as e:
        return ServiceStatus(
            name="azure_computer_vision",
            available=False,
            reason=f"Connection failed: {e}",
            env_vars=["AZURE_VISION_ENDPOINT", "AZURE_VISION_KEY"],
        )


def _check_azure_docint() -> ServiceStatus:
    """Check Azure Document Intelligence credentials and connectivity."""
    endpoint = os.getenv("AZURE_DOCINT_ENDPOINT", "").strip()
    key = os.getenv("AZURE_DOCINT_KEY", "").strip()
    if not endpoint or not key:
        missing = []
        if not endpoint:
            missing.append("AZURE_DOCINT_ENDPOINT")
        if not key:
            missing.append("AZURE_DOCINT_KEY")
        return ServiceStatus(
            name="azure_docint",
            available=False,
            reason=f"Missing: {', '.join(missing)}",
            env_vars=missing,
        )
    try:
        import urllib.request
        import urllib.error

        # Direct HTTP ping — any response (including 401/404) confirms the
        # endpoint is reachable. Only connection errors mean unavailable.
        url = endpoint.rstrip("/") + "/documentintelligence/info?api-version=2024-02-29-preview"
        req = urllib.request.Request(url, headers={"Ocp-Apim-Subscription-Key": key})
        try:
            urllib.request.urlopen(req, timeout=10)
        except urllib.error.HTTPError:
            # Any HTTP error (401, 404 etc) = endpoint reachable = OK
            # (FIX: dropped unused ``as http_err`` binding.)
            pass
        # Also verify the SDK can be imported and client instantiated
        from azure.ai.documentintelligence import DocumentIntelligenceClient
        from azure.core.credentials import AzureKeyCredential

        DocumentIntelligenceClient(endpoint, AzureKeyCredential(key))
        return ServiceStatus(name="azure_docint", available=True)
    except Exception as e:
        return ServiceStatus(
            name="azure_docint",
            available=False,
            reason=f"Connection failed: {e}",
            env_vars=["AZURE_DOCINT_ENDPOINT", "AZURE_DOCINT_KEY"],
        )


def _check_openai() -> ServiceStatus:
    """Check OpenAI credentials and connectivity (via ``models.list``)."""
    key = os.getenv("OPENAI_KEY", "").strip()
    # FIX: removed unused local ``model`` (OPENAI_MODEL is read where it is
    # actually needed, in run_preflight).
    if not key:
        return ServiceStatus(
            name="openai",
            available=False,
            reason="Missing: OPENAI_KEY",
            env_vars=["OPENAI_KEY"],
        )
    # Default base URL if not set
    if not os.getenv("OPENAI_BASE_URL"):
        _set_env("OPENAI_BASE_URL", OPENAI_DEFAULT_BASE_URL)
    try:
        from openai import OpenAI

        client = OpenAI(api_key=key, base_url=os.getenv("OPENAI_BASE_URL"))
        # Lightweight connectivity test — list models
        client.models.list()
        return ServiceStatus(name="openai", available=True)
    except Exception as e:
        return ServiceStatus(
            name="openai",
            available=False,
            reason=f"Connection failed: {e}",
            env_vars=["OPENAI_KEY"],
        )


# ─────────────────────────────────────────────
# Prompt for a single failing service
# ─────────────────────────────────────────────

def _handle_service_failure(
    status: ServiceStatus,
    needed_for: str,
    *,
    allow_continue: bool = True,
    max_retries: int = 3,
) -> bool:
    """Interactively handle a service that failed its preflight check.

    Args:
        status: The failing service's status (name, reason, env vars).
        needed_for: What the service is needed for.
            NOTE(review): currently unused in the message text; kept for
            interface compatibility — confirm whether callers expect it shown.
        allow_continue: Whether "continue without this service" is offered.
        max_retries: How many credential-entry attempts to allow.

    Returns:
        True if the service is now available (key was provided and
        connection succeeded), False if the user chose to continue
        without it. Calls ``sys.exit(1)`` if the user chooses to quit.
    """
    # FIX: original printed name and reason with no separator
    # (e.g. "openaiMissing: OPENAI_KEY").
    print(f"\n{RED}[!] {status.name}: {status.reason}{RESET}")

    # Dispatch table for re-checking a service after credentials are entered
    # (replaces the original if/elif chain).
    recheckers = {
        "azure_computer_vision": _check_azure_cv,
        "azure_docint": _check_azure_docint,
        "openai": _check_openai,
    }

    for attempt in range(max_retries):
        options = ["Quit"]
        if allow_continue:
            options.append(f"Continue without {status.name}")
        if status.env_vars:
            options.append("Enter credentials now")

        choice = _prompt_choice("", options)
        chosen = options[choice - 1]

        if chosen == "Quit":
            print(f"\n{YELLOW}Exiting.{RESET}")
            sys.exit(1)
        if chosen.startswith("Continue without"):
            return False
        if chosen == "Enter credentials now":
            for var in status.env_vars:
                val = _prompt_key(var)
                if val:
                    _set_env(var, val)
            # Re-check the service
            print(f" {DIM}Testing connection...{RESET}")
            checker = recheckers.get(status.name)
            new_status = checker() if checker is not None else status
            if new_status.available:
                print(f" {GREEN}[✓] {status.name} connected.{RESET}")
                return True
            else:
                print(f" {RED}[!] Still failing: {new_status.reason}{RESET}")
                status = new_status
                if attempt < max_retries - 1:
                    print(f" {DIM}({max_retries - attempt - 1} attempt(s) remaining){RESET}")

    print(f" {YELLOW}Max retries reached for {status.name}.{RESET}")
    if allow_continue:
        options = ["Quit", f"Continue without {status.name}"]
        choice = _prompt_choice("", options)
        if options[choice - 1] == "Quit":
            sys.exit(1)
        return False
    sys.exit(1)


# ─────────────────────────────────────────────
# Main preflight runner
# ─────────────────────────────────────────────
def _check_cloud_method(result: PreflightResult, method: str, checker, needed_for: str) -> None:
    """Run one cloud extraction method's preflight check, updating *result*.

    On failure the user is prompted via ``_handle_service_failure``; if they
    decline to provide credentials, the method is removed from
    ``result.active_methods`` and marked "Skipped by user".
    (Extracted: this logic was duplicated for Azure CV and Azure DocInt.)
    """
    print(f" Checking {method}...", end=" ", flush=True)
    status = checker()
    result.services[method] = status
    if status.available:
        print(f"{GREEN}OK{RESET}")
        return
    print(f"{RED}FAILED{RESET}")
    available = _handle_service_failure(status, needed_for=needed_for, allow_continue=True)
    if not available:
        result.active_methods.discard(method)
        result.services[method] = ServiceStatus(
            name=method,
            available=False,
            reason="Skipped by user",
        )


def _attach_openai_client(result: PreflightResult) -> None:
    """Mark OpenAI as available on *result* and build a client from env vars.

    (Extracted: this construction was duplicated in both the "check passed"
    and "credentials entered interactively" branches.)
    """
    result.openai_available = True
    result.openai_model = os.getenv("OPENAI_MODEL", OPENAI_DEFAULT_MODEL)
    from openai import OpenAI
    result.openai_client = OpenAI(
        api_key=os.getenv("OPENAI_KEY"),
        base_url=os.getenv("OPENAI_BASE_URL", OPENAI_DEFAULT_BASE_URL),
    )


def run_preflight(
    requested_methods: set[str],
    needs_reconcile: bool,
    needs_entities: bool = False,
    needs_embed: bool = False,
    *,
    verbose: int = 0,
) -> PreflightResult:
    """Run preflight checks for all services required by the requested methods.

    Checks credentials and connectivity for each service. If a service
    fails, prompts the user to quit, continue without it, or enter
    credentials interactively.

    Args:
        requested_methods: Set of extraction method names requested.
        needs_reconcile: Whether reconciliation (OpenAI) is needed.
        needs_entities: Whether entity extraction (OpenAI) is needed.
        needs_embed: Whether embedding generation is needed.
        verbose: Verbosity level (0-3).

    Returns:
        PreflightResult with the final set of available methods and an
        OpenAI client if reconciliation is available.
    """
    result = PreflightResult()
    result.active_methods = set(requested_methods)

    print(f"\n{BOLD}[~] Preflight checks{RESET}")

    # ── Azure Computer Vision / Document Intelligence ──
    if "azure_computer_vision" in requested_methods:
        _check_cloud_method(result, "azure_computer_vision", _check_azure_cv,
                            "azure_computer_vision extraction")
    if "azure_docint" in requested_methods:
        _check_cloud_method(result, "azure_docint", _check_azure_docint,
                            "azure_docint extraction")

    # ── Local methods — no preflight needed ──
    for method in ("pymupdf", "pdfplumber"):
        if method in requested_methods:
            result.services[method] = ServiceStatus(name=method, available=True)
            if verbose >= 2:
                print(f" {method}: {GREEN}OK{RESET} {DIM}(local, no credentials required){RESET}")

    # ── OpenAI ── (needed for reconciliation and/or entity extraction)
    if needs_reconcile or needs_entities:
        print(" Checking openai...", end=" ", flush=True)
        status = _check_openai()
        result.services["openai"] = status
        if status.available:
            print(f"{GREEN}OK{RESET}")
            _attach_openai_client(result)
        else:
            print(f"{RED}FAILED{RESET}")
            available = _handle_service_failure(
                status,
                needed_for="reconciliation",
                allow_continue=True,
            )
            if available:
                _attach_openai_client(result)
            else:
                result.openai_available = False
                result.services["openai"] = ServiceStatus(
                    name="openai",
                    available=False,
                    reason="Skipped by user",
                )

    # ── Embedding (Ollama or OpenAI) ──
    if needs_embed:
        # NOTE(review): detect_provider appears unused below — confirm the
        # import has no required side effects before removing it.
        from .embed import detect_provider

        ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "").strip()
        ollama_model = os.getenv("OLLAMA_EMODEL", "").strip()
        openai_emodel = os.getenv("OPENAI_EMODEL", "").strip()

        if ollama_endpoint and ollama_model:
            print(" Checking ollama embeddings...", end=" ", flush=True)
            try:
                import urllib.request
                import urllib.error

                req = urllib.request.Request(
                    ollama_endpoint.rstrip("/") + "/api/tags",
                    method="GET",
                )
                try:
                    urllib.request.urlopen(req, timeout=10)
                except urllib.error.HTTPError:
                    pass  # Any HTTP response = reachable
                result.services["ollama_embed"] = ServiceStatus(
                    name="ollama_embed", available=True)
                result.embed_provider = "ollama"
                print(f"{GREEN}OK{RESET}")
            except Exception as e:
                print(f"{RED}FAILED{RESET}")
                status = ServiceStatus(
                    name="ollama_embed",
                    available=False,
                    reason=f"Connection failed: {e}",
                    env_vars=["OLLAMA_ENDPOINT", "OLLAMA_EMODEL"],
                )
                available = _handle_service_failure(
                    status, needed_for="embeddings", allow_continue=True)
                if available:
                    result.services["ollama_embed"] = ServiceStatus(
                        name="ollama_embed", available=True)
                    result.embed_provider = "ollama"
                else:
                    result.services["ollama_embed"] = ServiceStatus(
                        name="ollama_embed", available=False, reason="Skipped by user")
        elif openai_emodel and os.getenv("OPENAI_KEY"):
            # Fall back to OpenAI for embeddings
            result.services["openai_embed"] = ServiceStatus(
                name="openai_embed", available=True)
            result.embed_provider = "openai"
            if verbose >= 2:
                print(f" openai embeddings: {GREEN}OK{RESET} {DIM}(using existing OpenAI key){RESET}")
        else:
            missing = []
            if not ollama_endpoint:
                missing.append("OLLAMA_ENDPOINT")
            if not ollama_model:
                missing.append("OLLAMA_EMODEL")
            status = ServiceStatus(
                name="embed",
                available=False,
                reason="No embedding provider configured. Set OLLAMA_ENDPOINT+OLLAMA_EMODEL or OPENAI_KEY+OPENAI_EMODEL",
                env_vars=missing,
            )
            print("\n Checking embeddings...", end=" ", flush=True)
            print(f"{RED}FAILED{RESET}")
            available = _handle_service_failure(
                status, needed_for="embeddings", allow_continue=True)
            if not available:
                result.embed_provider = None

    # ── Summary ──
    print(f"\n {BOLD}Services:{RESET}")
    for name, svc in result.services.items():
        if svc.available:
            print(f" {GREEN}[✓]{RESET} {name}")
        else:
            print(f" {YELLOW}[~]{RESET} {name} {DIM}{svc.reason}{RESET}")

    if result.active_methods:
        print(f"\n Active methods : {', '.join(sorted(result.active_methods))}")
    elif requested_methods:
        # Methods were requested but none are available
        print(f"\n{RED}[!] No extraction methods available. Exiting.{RESET}")
        sys.exit(1)

    if needs_reconcile:
        if result.openai_available:
            print(f" Reconciliation : {GREEN}enabled{RESET} ({result.openai_model})")
        else:
            print(f" Reconciliation : {YELLOW}disabled{RESET}")

    if needs_entities:
        if result.openai_available:
            from .entities import _get_entity_model
            entity_model = _get_entity_model()
            print(f" Entity extract : {GREEN}enabled{RESET} ({entity_model})")
        else:
            print(f" Entity extract : {YELLOW}disabled{RESET}")

    if needs_embed:
        if result.embed_provider:
            model = (os.getenv("OLLAMA_EMODEL") if result.embed_provider == "ollama"
                     else os.getenv("OPENAI_EMODEL"))
            print(f" Embeddings : {GREEN}enabled{RESET} ({result.embed_provider} / {model})")
        else:
            print(f" Embeddings : {YELLOW}disabled{RESET}")

    print()
    return result