from __future__ import annotations
import io
import os
import re
import time
import json
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, Tuple
from collections import Counter
from pypdf import PdfReader, PdfWriter
# Azure Computer Vision (Read OCR)
from azure.cognitiveservices.vision.computervision import ComputerVisionClient
from msrest.authentication import CognitiveServicesCredentials
# Azure Document Intelligence (Layout + KV)
# pip install azure-ai-documentintelligence
from azure.ai.documentintelligence import DocumentIntelligenceClient
from azure.core.credentials import AzureKeyCredential
from askemblaex.env import load_env
load_env()
# ============================================================================
# Shared helpers / types
# ============================================================================
PathOrBytes = Union[str, Path, bytes, bytearray]
def _get_logger(logger: Optional[logging.Logger], default_name: str) -> logging.Logger:
return logger or logging.getLogger(default_name)
def _format_bytes(n: int) -> str:
mb = n / (1024 * 1024)
if mb >= 1024:
return f"{mb / 1024:.2f} GB"
return f"{mb:.2f} MB"
def _ensure_dir(path: Path) -> None:
path.mkdir(parents=True, exist_ok=True)
def _safe_as_dict(obj: Any) -> Dict[str, Any]:
"""Azure SDK models often implement as_dict(); use it if present."""
try:
fn = getattr(obj, "as_dict", None)
if callable(fn):
return fn()
except Exception:
pass
return {}
def _normalize_status(status: Any) -> str:
"""
Normalize SDK status values to simple lowercase strings.
Handles:
- OperationStatusCodes.succeeded -> "succeeded" (via .value)
- "operationstatuscodes.succeeded" -> "succeeded"
- "Succeeded" -> "succeeded"
"""
if status is None:
return "unknown"
v = getattr(status, "value", None) # Enum-like objects often have .value
if isinstance(v, str) and v:
return v.strip().lower()
s = str(status).strip().lower()
if "." in s:
s = s.rsplit(".", 1)[-1]
return s or "unknown"
# ============================================================================
# Rate limiter
# ============================================================================
class RateLimiter:
"""
Thread-safe fixed-interval rate limiter.
rate_per_sec <= 0 disables limiting.
"""
def __init__(self, rate_per_sec: float, *, logger: Optional[logging.Logger] = None, name: str = "ratelimit"):
self.rate_per_sec = float(rate_per_sec)
self._min_interval = 0.0 if self.rate_per_sec <= 0 else (1.0 / self.rate_per_sec)
self._lock = threading.Lock()
self._next_allowed = 0.0
self._log = _get_logger(logger, name)
if self.rate_per_sec > 0:
self._log.info("RateLimiter enabled. max_rps=%.3f min_interval=%.6fs", self.rate_per_sec, self._min_interval)
else:
self._log.info("RateLimiter disabled (rate_per_sec <= 0).")
    def acquire(self, *, what: str = "request", page_index: Optional[int] = None) -> None:
        if self._min_interval <= 0:
            return
        with self._lock:
            now = time.monotonic()
            if now < self._next_allowed:
                wait_s = self._next_allowed - now
                self._log.debug("Rate limit wait. what=%s page_index=%s wait=%.3fs", what, page_index, wait_s)
                time.sleep(wait_s)
                now = time.monotonic()
            self._next_allowed = now + self._min_interval
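# Illustrative sketch (not part of the pipeline): pace a trivial loop to
# roughly 2 calls/second with acquire(). The values here are arbitrary.
def _demo_rate_limiter() -> None:
    limiter = RateLimiter(2.0)
    for i in range(5):
        limiter.acquire(what="demo", page_index=i)
        print(f"tick {i} at t={time.monotonic():.2f}")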
# ============================================================================
# PDF I/O + splitting (single source of truth)
# ============================================================================
def pdf_to_bytes(pdf: PathOrBytes, *, logger: Optional[logging.Logger] = None) -> bytes:
"""
Load a PDF from disk (str/Path) or accept bytes/bytearray.
"""
log = _get_logger(logger, "askemblaex.pdf")
if isinstance(pdf, (bytes, bytearray)):
b = bytes(pdf)
log.debug("PDF provided as bytes. size=%s", _format_bytes(len(b)))
return b
if isinstance(pdf, Path):
pdf = str(pdf)
if isinstance(pdf, str):
log.debug("PDF provided as file path: %s", pdf)
if not os.path.exists(pdf):
log.error("PDF file does not exist: %s", pdf)
raise FileNotFoundError(f"PDF file does not exist: {pdf}")
if not os.path.isfile(pdf):
log.error("PDF path exists but is not a file: %s", pdf)
raise FileNotFoundError(f"PDF path is not a file: {pdf}")
b = Path(pdf).read_bytes()
log.info("Read PDF from disk. path=%s size=%s", pdf, _format_bytes(len(b)))
return b
raise TypeError("pdf must be a file path (str/Path) or bytes/bytearray")
def split_pdf_to_single_page_pdfs(pdf_bytes: bytes, *, logger: Optional[logging.Logger] = None) -> List[bytes]:
"""
Split a PDF (bytes) into one-page PDFs (list[bytes]).
"""
log = _get_logger(logger, "askemblaex.pdf")
log.info("Splitting PDF into single-page PDFs. input_size=%s", _format_bytes(len(pdf_bytes)))
reader = PdfReader(io.BytesIO(pdf_bytes))
page_count = len(reader.pages)
log.info("PDF parsed. pages=%d", page_count)
pages: List[bytes] = []
for i in range(page_count):
writer = PdfWriter()
writer.add_page(reader.pages[i])
buf = io.BytesIO()
writer.write(buf)
page_bytes = buf.getvalue()
pages.append(page_bytes)
log.debug("Created page PDF. page_index=%d page_size=%s", i, _format_bytes(len(page_bytes)))
return pages
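# Illustrative sketch: load and split a PDF entirely in memory. "sample.pdf"
# is a hypothetical path.
def _demo_split_pdf(path: str = "sample.pdf") -> None:
    pages = split_pdf_to_single_page_pdfs(pdf_to_bytes(path))
    for i, page in enumerate(pages):
        print(f"page {i}: {_format_bytes(len(page))}")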
# ============================================================================
# Computer Vision Read OCR types + functions
# ============================================================================
@dataclass
class OCRLine:
text: str
bounding_box: Optional[List[float]] = None
words: Optional[List[Dict[str, Any]]] = None
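# OCRPageResult is referenced throughout this module but its definition did
# not survive in this source view; the dataclass below is a minimal
# reconstruction inferred from the call sites (page_index/status/lines/raw).
@dataclass
class OCRPageResult:
    page_index: int
    status: str  # normalized: "succeeded" | "failed" | "timeout" | "unknown"
    lines: List[OCRLine]
    raw: Any = None  # original SDK read result when return_raw=True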
def _extract_operation_id(operation_location: str) -> str:
return operation_location.rstrip("/").split("/")[-1]
def _extract_error_from_read_result(rr: Any) -> Optional[Dict[str, Any]]:
"""
Best-effort extraction of error payload from CV Read result.
Shape varies across SDK versions.
"""
d = _safe_as_dict(rr)
if "error" in d and isinstance(d["error"], dict):
return d["error"]
for k in ("analyze_result", "analyzeResult"):
if k in d and isinstance(d[k], dict) and isinstance(d[k].get("error"), dict):
return d[k]["error"]
return None
def _parse_read_result_to_lines(read_result: Any) -> List[OCRLine]:
"""
Convert Azure CV Read result -> list of OCRLine.
"""
lines_out: List[OCRLine] = []
analyze = getattr(read_result, "analyze_result", None)
if not analyze:
return lines_out
pages = getattr(analyze, "read_results", None) or []
for page in pages:
for line in (getattr(page, "lines", None) or []):
words_payload: List[Dict[str, Any]] = []
for w in (getattr(line, "words", None) or []):
words_payload.append({
"text": getattr(w, "text", ""),
"confidence": getattr(w, "confidence", None),
"bounding_box": getattr(w, "bounding_box", None),
})
lines_out.append(
OCRLine(
text=getattr(line, "text", "") or "",
bounding_box=getattr(line, "bounding_box", None),
words=words_payload or None,
)
)
return lines_out
def _cv_submit_read_job(
client: ComputerVisionClient,
page_pdf_bytes: bytes,
*,
page_index: int,
rate_limiter: Optional[RateLimiter],
logger: logging.Logger,
) -> str:
if rate_limiter:
rate_limiter.acquire(what="cv_submit", page_index=page_index)
logger.info("Submitting CV Read job. page_index=%d payload_size=%s", page_index, _format_bytes(len(page_pdf_bytes)))
resp = client.read_in_stream(io.BytesIO(page_pdf_bytes), raw=True)
op_loc = resp.headers.get("Operation-Location") or resp.headers.get("operation-location")
if not op_loc:
logger.error("Missing Operation-Location header. page_index=%d", page_index)
raise RuntimeError("Missing Operation-Location header from Read API response.")
op_id = _extract_operation_id(op_loc)
logger.info("Submitted CV Read job. page_index=%d operation_id=%s", page_index, op_id)
return op_id
def _cv_get_read_status(
client: ComputerVisionClient,
operation_id: str,
*,
page_index: int,
rate_limiter: Optional[RateLimiter],
) -> Any:
if rate_limiter:
rate_limiter.acquire(what="cv_poll", page_index=page_index)
return client.get_read_result(operation_id)
def ocr_pdf_by_page_single_loop(
endpoint: str,
key: str,
pdf: PathOrBytes,
*,
    max_rps: Optional[float] = 8.0,
poll_batch_size: int = 25,
min_sleep_s: float = 0.25,
max_sleep_s: float = 3.0,
timeout_s: int = 7200,
per_page_timeout_s: int = 3600,
return_raw: bool = False,
logger: Optional[logging.Logger] = None,
) -> Dict[int, OCRPageResult]:
"""
OCR a PDF with Azure Computer Vision Read API (submit all pages then poll).
Notes:
- Splits the PDF into single-page PDFs in memory.
- Submits all pages, stores operation IDs, polls in batches.
- Status values are normalized; terminal states are "succeeded"/"failed".
"""
log = _get_logger(logger, "askemblaex.azure_cv")
log.info("Starting CV OCR. endpoint=%s", endpoint)
rate_limiter = RateLimiter(max_rps, logger=log, name="askemblaex.azure_cv.ratelimit") if max_rps is not None else None
client = ComputerVisionClient(endpoint, CognitiveServicesCredentials(key))
log.info("ComputerVisionClient created.")
pdf_bytes = pdf_to_bytes(pdf, logger=log)
page_pdfs = split_pdf_to_single_page_pdfs(pdf_bytes, logger=log)
page_count = len(page_pdfs)
log.info("Prepared pages. page_count=%d", page_count)
pending: List[Tuple[int, str, float]] = []
results: Dict[int, OCRPageResult] = {}
# Submit all jobs
for i, page_bytes in enumerate(page_pdfs):
try:
op_id = _cv_submit_read_job(client, page_bytes, page_index=i, rate_limiter=rate_limiter, logger=log)
pending.append((i, op_id, time.time()))
except Exception:
log.exception("CV submit failed. page_index=%d", i)
results[i] = OCRPageResult(page_index=i, status="failed", lines=[], raw=None)
last_status: Dict[int, str] = {}
overall_start = time.time()
last_progress_log = 0.0
while pending:
now = time.time()
if now - last_progress_log >= 5.0:
log.info("Polling pass. pending=%d completed=%d total=%d", len(pending), len(results), page_count)
last_progress_log = now
if now - overall_start > timeout_s:
log.error("Overall timeout. timeout_s=%d pending=%d", timeout_s, len(pending))
for (i, op_id, _) in pending:
results[i] = OCRPageResult(page_index=i, status="timeout", lines=[], raw=None)
pending.clear()
break
batch = pending[:poll_batch_size]
rest = pending[poll_batch_size:]
still_pending: List[Tuple[int, str, float]] = []
for (i, op_id, started_at) in batch:
if now - started_at > per_page_timeout_s:
log.error("Per-page timeout. page_index=%d operation_id=%s", i, op_id)
results[i] = OCRPageResult(page_index=i, status="timeout", lines=[], raw=None)
continue
try:
rr = _cv_get_read_status(client, op_id, page_index=i, rate_limiter=rate_limiter)
except Exception:
log.exception("CV poll failed. page_index=%d operation_id=%s", i, op_id)
results[i] = OCRPageResult(page_index=i, status="failed", lines=[], raw=None)
continue
status_l = _normalize_status(getattr(rr, "status", None))
prev = last_status.get(i)
if prev != status_l:
last_status[i] = status_l
log.info("Status change. page_index=%d operation_id=%s status=%s", i, op_id, status_l)
if status_l in ("succeeded", "failed"):
if status_l == "succeeded":
lines = _parse_read_result_to_lines(rr)
log.info("Completed. page_index=%d status=%s lines=%d", i, status_l, len(lines))
else:
err = _extract_error_from_read_result(rr)
if err:
log.error("Failed. page_index=%d error=%s", i, err)
else:
log.error("Failed. page_index=%d (no error payload)", i)
lines = []
results[i] = OCRPageResult(
page_index=i,
status=status_l,
lines=lines,
raw=rr if return_raw else None,
)
else:
still_pending.append((i, op_id, started_at))
pending = still_pending + rest
        # Sleep strategy: fast when few pending; slower when many.
        # Skip the sleep entirely once nothing is pending.
        if pending:
            if len(pending) <= 5:
                sleep_s = min_sleep_s
            else:
                frac_pending = len(pending) / max(page_count, 1)
                sleep_s = min(max_sleep_s, min_sleep_s + frac_pending * max_sleep_s)
            time.sleep(sleep_s)
# Ensure all pages have an entry
for i in range(page_count):
if i not in results:
results[i] = OCRPageResult(page_index=i, status="unknown", lines=[], raw=None)
log.warning("Missing page result; filled unknown. page_index=%d", i)
counts = Counter(r.status for r in results.values())
log.info("Status breakdown: %s", dict(counts))
log.info("CV OCR finished. pages=%d succeeded=%d", len(results), sum(1 for r in results.values() if r.status == "succeeded"))
return dict(sorted(results.items(), key=lambda kv: kv[0]))
def ocr_results_to_text_by_page(results: Dict[int, OCRPageResult]) -> Dict[int, str]:
"""Convenience: flatten OCR lines into page-index -> text."""
return {i: "\n".join(line.text for line in r.lines) for i, r in results.items()}
# ============================================================================
# Document Intelligence (Layout + KeyValuePairs + basic entities)
# ============================================================================
@dataclass
class DIKeyValue:
key: str
value: str
confidence: Optional[float] = None
@dataclass
class DITableCell:
row: int
col: int
text: str
kind: Optional[str] = None # e.g. "content", "columnHeader"
@dataclass
class DITable:
row_count: int
column_count: int
cells: List[DITableCell]
@dataclass
class DIEntity:
type: str # "date" | "name"
text: str
page: int
source: str # "text" | "kv"
confidence: Optional[float] = None
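# DIPageResult, like OCRPageResult above, is referenced but not defined in this
# source view; the dataclass below is a minimal reconstruction inferred from
# the construction site in docint_analyze_pages_layout_kv.
@dataclass
class DIPageResult:
    page_index: int
    content: str
    tables: List[DITable]
    key_values: List[DIKeyValue]
    entities: List[DIEntity]
    raw: Any = None  # original SDK analyze result when return_raw=True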
# Very lightweight date/name heuristics (swap later for real NER if desired)
_DATE_PATTERNS: List[re.Pattern] = [
re.compile(r"\b(1[6-9]\d{2}|20\d{2})\b"), # 1600-2099
re.compile(r"\b(?:\d{1,2}[-/])?\d{1,2}[-/](?:\d{2}|\d{4})\b"), # 1/2/1725, 01-02-1725
re.compile(r"\b(?:jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)\w*\s+\d{1,2},?\s+(1[6-9]\d{2}|20\d{2})\b", re.I),
]
_NAME_PATTERN = re.compile(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,4})\b")
def _extract_entities(text: str, *, page_index: int, source: str) -> List[DIEntity]:
ents: List[DIEntity] = []
for pat in _DATE_PATTERNS:
for m in pat.finditer(text):
ents.append(DIEntity(type="date", text=m.group(0), page=page_index, source=source))
for m in _NAME_PATTERN.finditer(text):
cand = m.group(1).strip()
if len(cand) < 5:
continue
ents.append(DIEntity(type="name", text=cand, page=page_index, source=source))
return ents
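# Quick illustration of the heuristics above; they intentionally over-match,
# so expect some noise on real documents.
def _demo_entities() -> None:
    sample = "Deed of John Smith, recorded 3/14/1725 in Boston."
    for e in _extract_entities(sample, page_index=0, source="text"):
        print(e.type, repr(e.text))
    # -> date '1725', date '3/14/1725', name 'John Smith'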
def make_docint_client(endpoint: str, key: str) -> DocumentIntelligenceClient:
"""
Create a Document Intelligence client.
endpoint: https://<resource>.cognitiveservices.azure.com/
"""
return DocumentIntelligenceClient(endpoint=endpoint, credential=AzureKeyCredential(key))
def docint_analyze_pages_layout_kv(
endpoint: str,
key: str,
page_pdfs: List[bytes],
*,
include_key_value_pairs: bool = True,
    max_rps: Optional[float] = 4.0,
return_raw: bool = False,
logger: Optional[logging.Logger] = None,
) -> Dict[int, DIPageResult]:
"""
Analyze already-split single-page PDFs with Document Intelligence.
Model:
- prebuilt-layout (tables + structure)
Extras:
- features=["keyValuePairs"] when include_key_value_pairs=True
Returns:
dict[page_index] -> DIPageResult (tables + KV + basic extracted entities)
"""
log = _get_logger(logger, "askemblaex.docint")
client = make_docint_client(endpoint, key)
rate_limiter = RateLimiter(max_rps, logger=log, name="askemblaex.docint.ratelimit") if max_rps is not None else None
features = ["keyValuePairs"] if include_key_value_pairs else None
out: Dict[int, DIPageResult] = {}
for i, page_bytes in enumerate(page_pdfs):
if rate_limiter:
rate_limiter.acquire(what="di_analyze", page_index=i)
poller = client.begin_analyze_document(
model_id="prebuilt-layout",
body=io.BytesIO(page_bytes),
features=features,
)
result = poller.result()
content = (getattr(result, "content", "") or "").strip()
# Tables
tables_out: List[DITable] = []
for t in (getattr(result, "tables", None) or []):
cells: List[DITableCell] = []
for c in (getattr(t, "cells", None) or []):
cells.append(
DITableCell(
row=int(getattr(c, "row_index", 0)),
col=int(getattr(c, "column_index", 0)),
text=(getattr(c, "content", "") or "").strip(),
kind=getattr(c, "kind", None),
)
)
tables_out.append(
DITable(
row_count=int(getattr(t, "row_count", 0)),
column_count=int(getattr(t, "column_count", 0)),
cells=cells,
)
)
# Key-values
kv_out: List[DIKeyValue] = []
for kv in (getattr(result, "key_value_pairs", None) or []):
k = getattr(kv, "key", None)
v = getattr(kv, "value", None)
key_txt = (getattr(k, "content", "") or "").strip()
val_txt = (getattr(v, "content", "") or "").strip()
conf = getattr(kv, "confidence", None)
if key_txt or val_txt:
kv_out.append(DIKeyValue(key=key_txt, value=val_txt, confidence=conf))
# Entities (basic heuristics): from content + KV values
ents = []
ents.extend(_extract_entities(content, page_index=i, source="text"))
for kv in kv_out:
ents.extend(_extract_entities(kv.value, page_index=i, source="kv"))
log.info("DocInt analyzed. page_index=%d chars=%d tables=%d kv=%d ents=%d", i, len(content), len(tables_out), len(kv_out), len(ents))
out[i] = DIPageResult(
page_index=i,
content=content,
tables=tables_out,
key_values=kv_out,
entities=ents,
raw=result if return_raw else None,
)
return out
def docint_analyze_pdf_layout_kv(
endpoint: str,
key: str,
pdf: PathOrBytes,
*,
include_key_value_pairs: bool = True,
max_rps: float = 4.0,
return_raw: bool = False,
logger: Optional[logging.Logger] = None,
) -> Dict[int, DIPageResult]:
"""
Convenience wrapper:
PDF -> split into single-page PDFs -> Document Intelligence per page
"""
log = _get_logger(logger, "askemblaex.docint")
pdf_bytes = pdf_to_bytes(pdf, logger=log)
pages = split_pdf_to_single_page_pdfs(pdf_bytes, logger=log)
return docint_analyze_pages_layout_kv(
endpoint=endpoint,
key=key,
page_pdfs=pages,
include_key_value_pairs=include_key_value_pairs,
max_rps=max_rps,
return_raw=return_raw,
logger=log,
)
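# End-to-end sketch for the Document Intelligence path; the endpoint/key
# variable names and the path are illustrative assumptions.
def _demo_docint(path: str = "sample.pdf") -> None:
    pages = docint_analyze_pdf_layout_kv(
        endpoint=os.environ["AZURE_DI_ENDPOINT"],  # assumed variable name
        key=os.environ["AZURE_DI_KEY"],            # assumed variable name
        pdf=path,
    )
    for i, page in sorted(pages.items()):
        print(f"page {i}: tables={len(page.tables)} kv={len(page.key_values)} ents={len(page.entities)}")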
# ============================================================================
# Disk output helpers
# ============================================================================
def write_ocr_text_by_page(
out_dir: Path,
file_hash: str,
results: Dict[int, OCRPageResult],
*,
method: str = "azure_read",
logger: Optional[logging.Logger] = None,
newline: str = "\n",
) -> Dict[int, Path]:
"""
Write OCR text to disk, one file per page.
Filename format:
{hash}.{method}.page{page_index}.txt
"""
log = _get_logger(logger, "askemblaex.output")
_ensure_dir(out_dir)
safe_method = "".join(c if (c.isalnum() or c in ("-", "_")) else "_" for c in method).strip("_") or "ocr"
written: Dict[int, Path] = {}
for page_index in sorted(results.keys()):
page = results[page_index]
text = newline.join(line.text for line in page.lines) if page.lines else ""
filename = f"{file_hash}.{safe_method}.page{page_index}.txt"
p = out_dir / filename
p.write_text(text, encoding="utf-8")
written[page_index] = p
log.debug("Wrote OCR text. page_index=%d status=%s chars=%d path=%s", page_index, page.status, len(text), p)
log.info("Wrote OCR text files. count=%d out_dir=%s", len(written), out_dir)
return written
def write_docint_page_json(
out_dir: Path,
file_hash: str,
results: Dict[int, DIPageResult],
*,
method: str = "docint_layout",
logger: Optional[logging.Logger] = None,
) -> Dict[int, Path]:
"""
Write Document Intelligence results per page as JSON.
Filename format:
{hash}.{method}.page{page_index}.json
"""
log = _get_logger(logger, "askemblaex.output")
_ensure_dir(out_dir)
safe_method = "".join(c if (c.isalnum() or c in ("-", "_")) else "_" for c in method).strip("_") or "docint"
written: Dict[int, Path] = {}
for page_index in sorted(results.keys()):
r = results[page_index]
payload = {
"page_index": r.page_index,
"content": r.content,
"tables": [
{
"row_count": t.row_count,
"column_count": t.column_count,
"cells": [{"row": c.row, "col": c.col, "text": c.text, "kind": c.kind} for c in t.cells],
}
for t in r.tables
],
"key_values": [{"key": kv.key, "value": kv.value, "confidence": kv.confidence} for kv in r.key_values],
"entities": [{"type": e.type, "text": e.text, "page": e.page, "source": e.source, "confidence": e.confidence} for e in r.entities],
}
filename = f"{file_hash}.{safe_method}.page{page_index}.json"
p = out_dir / filename
p.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
written[page_index] = p
log.debug("Wrote DocInt JSON. page_index=%d path=%s", page_index, p)
log.info("Wrote DocInt JSON files. count=%d out_dir=%s", len(written), out_dir)
return written
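# Illustrative sketch tying the pieces together. Deriving the {hash} filename
# component by hashing the PDF bytes is one reasonable convention, but it is
# an assumption of this sketch, not something this module mandates.
def _demo_write_ocr_outputs(path: str = "sample.pdf") -> None:
    import hashlib
    pdf_bytes = pdf_to_bytes(path)
    file_hash = hashlib.sha256(pdf_bytes).hexdigest()[:16]
    results = ocr_pdf_by_page_single_loop(
        endpoint=os.environ["AZURE_CV_ENDPOINT"],  # assumed variable name
        key=os.environ["AZURE_CV_KEY"],            # assumed variable name
        pdf=pdf_bytes,
    )
    write_ocr_text_by_page(Path("ocr_out"), file_hash, results)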