Source code for askemblaex.embed

"""
askemblaex/embed.py

Embedding generation for reconciled page text.

Supports two providers, in priority order:
  1. Ollama  (OLLAMA_ENDPOINT, OLLAMA_EMODEL)
  2. OpenAI  (OPENAI_KEY, OPENAI_EMODEL)

By default only embeds pages where extractions.reconciled.text is populated.
If reconciled text is missing, the page is skipped with a warning.

Environment variables:
    OLLAMA_ENDPOINT   Ollama server URL e.g. http://localhost:11434
    OLLAMA_EMODEL     Ollama embedding model e.g. nomic-embed-text
    OLLAMA_EDIM       Optional expected embedding dimension (for validation)
    OPENAI_KEY        OpenAI API key (reused from reconciliation)
    OPENAI_EMODEL     OpenAI embedding model e.g. text-embedding-3-small
    OPENAI_EDIM       Optional expected embedding dimension (for validation)
"""

from __future__ import annotations

import json
import logging
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from .env import load_env
from .metadata import load_metadata, write_metadata
from .pages import get_page_number, save_or_merge_page

load_env()  # load .env into os.environ before any getenv lookups in this module

log = logging.getLogger("askemblaex.embed")

# ANSI escape sequences for coloured console output.
RED    = "\x1b[31m"
GREEN  = "\x1b[32m"
YELLOW = "\x1b[33m"
DIM    = "\x1b[2m"
RESET  = "\x1b[0m"


# ─────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────

def _utc_now() -> str:
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


def _console_error(msg: str, exc: Optional[Exception] = None, verbosity: int = 0) -> None:
    """Print a red error line to stderr; show exception detail per verbosity.

    At verbosity >= 3 a full traceback is printed; at >= 1 just the
    exception type and message.
    """
    print(f"    {RED}[!] {msg}{RESET}", file=sys.stderr)
    if exc is None:
        return
    if verbosity >= 3:
        import traceback
        traceback.print_exc()
    elif verbosity >= 1:
        print(f"    {DIM}{type(exc).__name__}: {exc}{RESET}", file=sys.stderr)


def _console_warn(msg: str) -> None:
    """Print a yellow warning line to stdout."""
    line = f"    {YELLOW}[~] {msg}{RESET}"
    print(line)


def _console_info(msg: str, verbosity: int) -> None:
    if verbosity >= 1:
        print(f"      {msg}")


def _console_debug(msg: str, verbosity: int) -> None:
    """Print a dimmed debug line to stdout when verbosity >= 2."""
    if verbosity < 2:
        return
    print(f"      {DIM}{msg}{RESET}")


# ─────────────────────────────────────────────
# Provider detection
# ─────────────────────────────────────────────

[docs] def detect_provider() -> Optional[str]: """ Return the active embedding provider name based on environment variables. Ollama takes priority over OpenAI. Returns 'ollama', 'openai', or None if neither is configured. """ if os.getenv("OLLAMA_ENDPOINT") and os.getenv("OLLAMA_EMODEL"): return "ollama" if os.getenv("OPENAI_KEY") and os.getenv("OPENAI_EMODEL"): return "openai" return None
# ─────────────────────────────────────────────
# Embedding calls
# ─────────────────────────────────────────────
[docs] def embed_ollama(text: str) -> list[float]: """ Generate embeddings using Ollama. Args: text: Text to embed. Returns: List of floats representing the embedding vector. Raises: ValueError: If OLLAMA_ENDPOINT or OLLAMA_EMODEL are not set. RuntimeError: If the Ollama API call fails. """ import urllib.request import urllib.error endpoint = os.getenv("OLLAMA_ENDPOINT", "").rstrip("/") model = os.getenv("OLLAMA_EMODEL", "") if not endpoint or not model: raise ValueError("OLLAMA_ENDPOINT and OLLAMA_EMODEL must be set.") url = f"{endpoint}/api/embeddings" payload = json.dumps({"model": model, "prompt": text}).encode("utf-8") req = urllib.request.Request( url, data=payload, headers={"Content-Type": "application/json"}, method="POST", ) try: with urllib.request.urlopen(req, timeout=60) as resp: data = json.loads(resp.read().decode("utf-8")) except urllib.error.HTTPError as e: raise RuntimeError(f"Ollama HTTP {e.code}: {e.reason}") from e except urllib.error.URLError as e: raise RuntimeError(f"Ollama connection failed: {e.reason}") from e if "embedding" not in data: raise RuntimeError(f"Ollama response missing 'embedding' key: {data}") return data["embedding"]
[docs] def embed_openai(text: str) -> list[float]: """ Generate embeddings using OpenAI. Args: text: Text to embed. Returns: List of floats representing the embedding vector. Raises: ValueError: If OPENAI_KEY or OPENAI_EMODEL are not set. """ from openai import OpenAI key = os.getenv("OPENAI_KEY", "") model = os.getenv("OPENAI_EMODEL", "") if not key or not model: raise ValueError("OPENAI_KEY and OPENAI_EMODEL must be set.") client = OpenAI(api_key=key) response = client.embeddings.create(input=text, model=model) return response.data[0].embedding
[docs] def generate_embedding(text: str, provider: str) -> list[float]: """ Generate an embedding vector using the specified provider. Args: text: Text to embed. provider: Either ``'ollama'`` or ``'openai'``. Returns: Embedding vector as a list of floats. """ if provider == "ollama": return embed_ollama(text) elif provider == "openai": return embed_openai(text) else: raise ValueError(f"Unknown embedding provider: {provider!r}")
# ───────────────────────────────────────────── # Dimension validation # ───────────────────────────────────────────── def _expected_dim(provider: str) -> Optional[int]: """Return the expected embedding dimension from env vars, or None.""" var = "OLLAMA_EDIM" if provider == "ollama" else "OPENAI_EDIM" raw = os.getenv(var, "").strip() if raw.isdigit(): return int(raw) return None def _model_name(provider: str) -> str: if provider == "ollama": return os.getenv("OLLAMA_EMODEL", "") return os.getenv("OPENAI_EMODEL", "") def _validate_dim(values: list[float], provider: str) -> int: """ Validate embedding dimensions against expected value if configured. Returns the actual dimension. Warns if mismatch. """ actual = len(values) expected = _expected_dim(provider) if expected is not None and actual != expected: _console_warn( f"Embedding dim mismatch for {provider}: " f"expected {expected}, got {actual}" ) log.warning( "Embedding dim mismatch — provider=%s expected=%d actual=%d", provider, expected, actual, ) return actual # ───────────────────────────────────────────── # Page-level embedding # ─────────────────────────────────────────────
def embed_page_file(
    page_file: Path,
    doc_id: str,
    page_num: int,
    parent_folder: Path,
    *,
    provider: str,
    force: bool = False,
    verbosity: int = 0,
) -> bool:
    """
    Embed the reconciled text of a single page file.

    Loads the page JSON, reads ``extractions.reconciled.text``, generates
    an embedding, and writes the result back into ``extractions.embedding``.

    Args:
        page_file: Path to the page JSON file.
        doc_id: Document ID (content hash).
        page_num: Zero-based page number.
        parent_folder: Hash-keyed document folder.
        provider: Embedding provider (``'ollama'`` or ``'openai'``).
        force: If True, re-embed even if already embedded.
        verbosity: Console verbosity level (0-3).

    Returns:
        True if an embedding was generated, False if the page was skipped.
    """
    try:
        page_data = json.loads(page_file.read_text(encoding="utf-8"))
    except Exception as e:
        _console_error(f"Failed to load {page_file.name}: {e}", exc=e, verbosity=verbosity)
        return False

    extractions = page_data.get("extractions", {})
    embedding = extractions.get("embedding", {})

    # Skip pages already embedded with the currently configured model,
    # unless the caller forces a re-embed.
    if not force:
        existing_values = embedding.get("values")
        existing_model = embedding.get("model")
        current_model = _model_name(provider)
        if existing_values and existing_model == current_model:
            _console_debug(
                f"page.{page_num:04d} already embedded with {existing_model}, skipping",
                verbosity)
            return False

    # Pull the reconciled text; without it there is nothing to embed.
    reconciled = extractions.get("reconciled", {})
    text = (reconciled.get("text") or "").strip()
    if not text:
        _console_warn(f"page.{page_num:04d} — no reconciled text, skipping")
        log.warning("No reconciled text for page %d in %s", page_num, doc_id)
        return False

    _console_debug(f"Embedding page.{page_num:04d} ({len(text)} chars)...", verbosity)
    try:
        values = generate_embedding(text, provider)
    except Exception as e:
        _console_error(
            f"Embedding failed page.{page_num:04d}: {e}", exc=e, verbosity=verbosity)
        log.error("Embedding failed page %d: %s", page_num, e)
        return False

    dim = _validate_dim(values, provider)

    # Persist the vector back into the page file under extractions.embedding.
    now = _utc_now()
    save_or_merge_page(parent_folder, doc_id, page_num, {
        "embedding": {
            "values": values,
            "model": _model_name(provider),
            "provider": provider,
            "dim": dim,
            "created_at": now,
        }
    })
    _console_debug(
        f"page.{page_num:04d} embedded — dim={dim} provider={provider}",
        verbosity)
    log.info("Embedded page %d dim=%d provider=%s", page_num, dim, provider)
    return True
# ─────────────────────────────────────────────
# Folder-level embedding
# ─────────────────────────────────────────────
def embed_folder(
    folder: Path,
    *,
    provider: str,
    force: bool = False,
    verbosity: int = 0,
) -> tuple[int, int]:
    """
    Embed all reconciled pages in a hash-keyed document folder.

    Reads each page file under ``<folder>/pages/``, generates embeddings
    for pages that have reconciled text, and writes results back into
    each page's ``extractions.embedding`` slot. After all pages are
    processed, stamps ``extraction.steps.embeddings = true`` in the
    document metadata.

    Args:
        folder: Hash-keyed document folder.
        provider: Embedding provider (``'ollama'`` or ``'openai'``).
        force: If True, re-embed even if already embedded with same model.
        verbosity: Console verbosity level (0-3).

    Returns:
        A ``(embedded_count, skipped_count)`` tuple.
    """
    pages_dir = folder / "pages"
    if not pages_dir.is_dir():
        log.debug("No pages/ folder in %s", folder.name)
        return 0, 0

    doc_id = folder.name
    page_files = sorted(pages_dir.glob("*.json"))
    if not page_files:
        log.debug("No page files in %s", pages_dir)
        return 0, 0

    # FIX: messages previously glued the folder name and page count together
    # ("<hash><n> pages"); add an explicit separator.
    log.info("Embedding folder %s — %d pages provider=%s",
             folder.name, len(page_files), provider)
    if verbosity >= 1:
        print(f"\n[>] {folder.name} — {len(page_files)} pages")

    embedded_count = 0
    skipped_count = 0
    for page_file in page_files:
        page_num = get_page_number(page_file)
        if page_num is None:
            log.warning("Could not parse page number from %s, skipping", page_file.name)
            continue
        ran = embed_page_file(
            page_file,
            doc_id=doc_id,
            page_num=page_num,
            parent_folder=folder,
            provider=provider,
            force=force,
            verbosity=verbosity,
        )
        if ran:
            embedded_count += 1
            if verbosity >= 1:
                print(f"  {GREEN}[✓]{RESET} page.{page_num:04d}")
        else:
            skipped_count += 1
            if verbosity >= 2:
                print(f"  {DIM}[~] page.{page_num:04d} skipped{RESET}")

    # Stamp metadata once any page was embedded. FIX: use setdefault so a
    # metadata file missing the nested "extraction"/"steps" keys doesn't
    # raise KeyError.
    if embedded_count > 0:
        metadata = load_metadata(folder, doc_id)
        if metadata:
            metadata.setdefault("extraction", {}).setdefault("steps", {})["embeddings"] = True
            write_metadata(folder, doc_id, metadata)
            log.info("Stamped embeddings=true in metadata for %s", folder.name)

    return embedded_count, skipped_count