"""
askemblaex/embed.py
Embedding generation for reconciled page text.
Supports two providers, in priority order:
1. Ollama (OLLAMA_ENDPOINT, OLLAMA_EMODEL)
2. OpenAI (OPENAI_KEY, OPENAI_EMODEL)
By default only embeds pages where extractions.reconciled.text is populated.
If reconciled text is missing, the page is skipped with a warning.
Environment variables:
OLLAMA_ENDPOINT Ollama server URL e.g. http://localhost:11434
OLLAMA_EMODEL Ollama embedding model e.g. nomic-embed-text
OLLAMA_EDIM Optional expected embedding dimension (for validation)
OPENAI_KEY OpenAI API key (reused from reconciliation)
OPENAI_EMODEL OpenAI embedding model e.g. text-embedding-3-small
OPENAI_EDIM Optional expected embedding dimension (for validation)
"""
from __future__ import annotations
import json
import logging
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from .env import load_env
from .metadata import load_metadata, write_metadata
from .pages import get_page_number, save_or_merge_page
load_env()
log = logging.getLogger("askemblaex.embed")
RED = "\x1b[31m"
GREEN = "\x1b[32m"
YELLOW = "\x1b[33m"
DIM = "\x1b[2m"
RESET = "\x1b[0m"
# ─────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────
def _utc_now() -> str:
    """Current UTC time as an ISO-8601 ``YYYY-MM-DDTHH:MM:SSZ`` string."""
    now = datetime.now(timezone.utc)
    return f"{now:%Y-%m-%dT%H:%M:%S}Z"
def _console_error(msg: str, exc: Optional[Exception] = None, verbosity: int = 0) -> None:
    """Print a red error line to stderr; add exception detail per verbosity.

    verbosity >= 3 prints the full traceback, verbosity >= 1 prints just the
    exception type and message, and anything less stays silent about *exc*.
    """
    print(f" {RED}[!] {msg}{RESET}", file=sys.stderr)
    if not exc:
        return
    if verbosity >= 3:
        import traceback
        traceback.print_exc()
    elif verbosity >= 1:
        print(f" {DIM}{type(exc).__name__}: {exc}{RESET}", file=sys.stderr)
def _console_warn(msg: str) -> None:
    """Print a yellow warning line to stdout."""
    line = f" {YELLOW}[~] {msg}{RESET}"
    print(line)
def _console_info(msg: str, verbosity: int) -> None:
    """Print *msg* to stdout, but only when verbosity is at least 1."""
    if verbosity < 1:
        return
    print(f" {msg}")
def _console_debug(msg: str, verbosity: int) -> None:
    """Print a dimmed debug line to stdout, but only when verbosity is at least 2."""
    if verbosity < 2:
        return
    print(f" {DIM}{msg}{RESET}")
# ─────────────────────────────────────────────
# Provider detection
# ─────────────────────────────────────────────
def detect_provider() -> Optional[str]:
    """
    Return the active embedding provider name based on environment variables.

    Ollama takes priority over OpenAI.

    Returns:
        ``'ollama'``, ``'openai'``, or ``None`` if neither is configured.
    """
    # Probe providers in priority order; a provider is active only when
    # both of its environment variables are set to non-empty values.
    candidates = (
        ("ollama", ("OLLAMA_ENDPOINT", "OLLAMA_EMODEL")),
        ("openai", ("OPENAI_KEY", "OPENAI_EMODEL")),
    )
    for name, required in candidates:
        if all(os.getenv(var) for var in required):
            return name
    return None
# ─────────────────────────────────────────────
# Embedding calls
# ─────────────────────────────────────────────
def embed_ollama(text: str) -> list[float]:
    """
    Generate embeddings using Ollama.

    Args:
        text: Text to embed.

    Returns:
        List of floats representing the embedding vector.

    Raises:
        ValueError: If OLLAMA_ENDPOINT or OLLAMA_EMODEL are not set.
        RuntimeError: If the Ollama API call fails.
    """
    import urllib.error
    import urllib.request

    endpoint = os.getenv("OLLAMA_ENDPOINT", "").rstrip("/")
    model = os.getenv("OLLAMA_EMODEL", "")
    if not (endpoint and model):
        raise ValueError("OLLAMA_ENDPOINT and OLLAMA_EMODEL must be set.")

    body = json.dumps({"model": model, "prompt": text}).encode("utf-8")
    request = urllib.request.Request(
        f"{endpoint}/api/embeddings",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=60) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        raise RuntimeError(f"Ollama HTTP {e.code}: {e.reason}") from e
    except urllib.error.URLError as e:
        raise RuntimeError(f"Ollama connection failed: {e.reason}") from e

    # A well-formed response carries the vector under 'embedding'.
    try:
        return data["embedding"]
    except KeyError:
        raise RuntimeError(f"Ollama response missing 'embedding' key: {data}") from None
def embed_openai(text: str) -> list[float]:
    """
    Generate embeddings using OpenAI.

    Args:
        text: Text to embed.

    Returns:
        List of floats representing the embedding vector.

    Raises:
        ValueError: If OPENAI_KEY or OPENAI_EMODEL are not set.
    """
    key = os.getenv("OPENAI_KEY", "")
    model = os.getenv("OPENAI_EMODEL", "")
    if not key or not model:
        raise ValueError("OPENAI_KEY and OPENAI_EMODEL must be set.")
    # Import lazily and only AFTER validating config: a misconfigured
    # environment should surface as the documented ValueError, not as an
    # unrelated ImportError when the openai SDK happens to be missing.
    from openai import OpenAI

    client = OpenAI(api_key=key)
    response = client.embeddings.create(input=text, model=model)
    return response.data[0].embedding
def generate_embedding(text: str, provider: str) -> list[float]:
    """
    Generate an embedding vector using the specified provider.

    Args:
        text: Text to embed.
        provider: Either ``'ollama'`` or ``'openai'``.

    Returns:
        Embedding vector as a list of floats.

    Raises:
        ValueError: If *provider* is not a known provider name.
    """
    if provider == "ollama":
        return embed_ollama(text)
    if provider == "openai":
        return embed_openai(text)
    raise ValueError(f"Unknown embedding provider: {provider!r}")
# ─────────────────────────────────────────────
# Dimension validation
# ─────────────────────────────────────────────
def _expected_dim(provider: str) -> Optional[int]:
    """Return the expected embedding dimension from env vars, or None.

    Reads OLLAMA_EDIM or OPENAI_EDIM depending on *provider*; anything
    that is not a plain digit string yields None (i.e. no validation).
    """
    env_name = "OLLAMA_EDIM" if provider == "ollama" else "OPENAI_EDIM"
    raw_value = os.getenv(env_name, "").strip()
    return int(raw_value) if raw_value.isdigit() else None
def _model_name(provider: str) -> str:
    """Return the configured embedding model name for *provider* ('' if unset)."""
    var = "OLLAMA_EMODEL" if provider == "ollama" else "OPENAI_EMODEL"
    return os.getenv(var, "")
def _validate_dim(values: list[float], provider: str) -> int:
    """
    Validate embedding dimensions against expected value if configured.

    Returns the actual dimension. Warns if mismatch.
    """
    actual = len(values)
    expected = _expected_dim(provider)
    # No expected dimension configured means no validation is performed.
    mismatch = expected is not None and actual != expected
    if mismatch:
        _console_warn(
            f"Embedding dim mismatch for {provider}: "
            f"expected {expected}, got {actual}"
        )
        log.warning(
            "Embedding dim mismatch — provider=%s expected=%d actual=%d",
            provider, expected, actual,
        )
    return actual
# ─────────────────────────────────────────────
# Page-level embedding
# ─────────────────────────────────────────────
def embed_page_file(
    page_file: Path,
    doc_id: str,
    page_num: int,
    parent_folder: Path,
    *,
    provider: str,
    force: bool = False,
    verbosity: int = 0,
) -> bool:
    """
    Embed the reconciled text of a single page file.

    Loads the page JSON, reads ``extractions.reconciled.text``, generates
    an embedding, and writes the result back into ``extractions.embedding``.

    Args:
        page_file: Path to the page JSON file.
        doc_id: Document ID (content hash).
        page_num: Zero-based page number.
        parent_folder: Hash-keyed document folder.
        provider: Embedding provider (``'ollama'`` or ``'openai'``).
        force: If True, re-embed even if already embedded.
        verbosity: Console verbosity level (0-3).

    Returns:
        True if embedding was generated, False if skipped.
    """
    try:
        page = json.loads(page_file.read_text(encoding="utf-8"))
    except Exception as e:
        _console_error(f"Failed to load {page_file.name}: {e}", exc=e, verbosity=verbosity)
        return False

    extractions = page.get("extractions", {})
    prior = extractions.get("embedding", {})
    model = _model_name(provider)

    # Skip when already embedded with the same model (unless forced).
    existing_model = prior.get("model")
    if not force and prior.get("values") and existing_model == model:
        _console_debug(
            f"page.{page_num:04d} already embedded with {existing_model}, skipping",
            verbosity)
        return False

    # Only pages with reconciled text can be embedded.
    reconciled_text = (extractions.get("reconciled", {}).get("text") or "").strip()
    if not reconciled_text:
        _console_warn(f"page.{page_num:04d} — no reconciled text, skipping")
        log.warning("No reconciled text for page %d in %s", page_num, doc_id)
        return False

    _console_debug(f"Embedding page.{page_num:04d} ({len(reconciled_text)} chars)...", verbosity)
    try:
        vector = generate_embedding(reconciled_text, provider)
    except Exception as e:
        _console_error(
            f"Embedding failed page.{page_num:04d}: {e}", exc=e, verbosity=verbosity)
        log.error("Embedding failed page %d: %s", page_num, e)
        return False

    dim = _validate_dim(vector, provider)

    # Merge the result back into the page file's extractions.embedding slot.
    save_or_merge_page(parent_folder, doc_id, page_num, {
        "embedding": {
            "values": vector,
            "model": model,
            "provider": provider,
            "dim": dim,
            "created_at": _utc_now(),
        }
    })
    _console_debug(
        f"page.{page_num:04d} embedded — dim={dim} provider={provider}", verbosity)
    log.info("Embedded page %d dim=%d provider=%s", page_num, dim, provider)
    return True
# ─────────────────────────────────────────────
# Folder-level embedding
# ─────────────────────────────────────────────
def embed_folder(
    folder: Path,
    *,
    provider: str,
    force: bool = False,
    verbosity: int = 0,
) -> tuple[int, int]:
    """
    Embed all reconciled pages in a hash-keyed document folder.

    Reads each page file under ``<folder>/pages/``, generates embeddings
    for pages that have reconciled text, and writes results back into each
    page's ``extractions.embedding`` slot.

    After all pages are processed, stamps ``extraction.steps.embeddings = true``
    in the document metadata.

    Args:
        folder: Hash-keyed document folder.
        provider: Embedding provider (``'ollama'`` or ``'openai'``).
        force: If True, re-embed even if already embedded with same model.
        verbosity: Console verbosity level (0-3).

    Returns:
        A ``(embedded_count, skipped_count)`` tuple.
    """
    pages_dir = folder / "pages"
    if not pages_dir.is_dir():
        log.debug("No pages/ folder in %s", folder.name)
        return 0, 0
    doc_id = folder.name
    page_files = sorted(pages_dir.glob("*.json"))
    if not page_files:
        log.debug("No page files in %s", pages_dir)
        return 0, 0
    log.info("Embedding folder %s — %d pages provider=%s", folder.name, len(page_files), provider)
    if verbosity >= 1:
        print(f"\n[>] {folder.name} — {len(page_files)} pages")
    embedded_count = 0
    skipped_count = 0
    for page_file in page_files:
        page_num = get_page_number(page_file)
        if page_num is None:
            log.warning("Could not parse page number from %s, skipping", page_file.name)
            continue
        ran = embed_page_file(
            page_file,
            doc_id=doc_id,
            page_num=page_num,
            parent_folder=folder,
            provider=provider,
            force=force,
            verbosity=verbosity,
        )
        if ran:
            embedded_count += 1
            if verbosity >= 1:
                print(f" {GREEN}[✓]{RESET} page.{page_num:04d}")
        else:
            skipped_count += 1
            if verbosity >= 2:
                print(f" {DIM}[~] page.{page_num:04d} skipped{RESET}")
    # Stamp metadata only when at least one page was (re)embedded.
    if embedded_count > 0:
        metadata = load_metadata(folder, doc_id)
        if metadata:
            # setdefault guards against metadata written before the
            # extraction.steps structure existed — the previous direct
            # subscripting raised KeyError on such documents.
            metadata.setdefault("extraction", {}).setdefault("steps", {})["embeddings"] = True
            write_metadata(folder, doc_id, metadata)
            log.info("Stamped embeddings=true in metadata for %s", folder.name)
    return embedded_count, skipped_count