#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-10-13 11:01:42 (ywatanabe)"
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/scholar/metadata_engines/ScholarEngine.py
# ----------------------------------------
from __future__ import annotations
import os
__FILE__ = "./src/scitex/scholar/metadata_engines/ScholarEngine.py"
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------
import json
__FILE__ = __file__
import asyncio
import hashlib
import re
import time
from typing import Dict, List, Optional
import scitex_logging as logging
from tqdm import tqdm
from scitex_scholar.config import ScholarConfig
from .individual import (
ArXivEngine,
CrossRefEngine,
CrossRefLocalEngine,
OpenAlexEngine,
OpenAlexLocalEngine,
PubMedEngine,
SemanticScholarEngine,
URLDOIEngine,
)
logger = logging.getLogger(__name__)
[docs]
class ScholarEngine:
"""Aggregates metadata from multiple engines for enrichment."""
[docs]
def __init__(
self,
engines: List[str] = None,
config: ScholarConfig = None,
use_cache=True,
clear_cache=False,
):
self.name = self.__class__.__name__
self.config = config if config else ScholarConfig()
_engines_resolved = self.config.resolve("engines", engines)
self.engines: List[str] = (
list(_engines_resolved)
if isinstance(_engines_resolved, (list, tuple))
else []
)
self.use_cache = self.config.resolve("use_cache_engine", use_cache)
self._engine_instances = {}
self.rotation_manager = None
# Initialize cache
self._setup_cache(clear_cache)
[docs]
def _setup_cache(self, clear_cache=False):
"""Setup cache directory and files."""
self.cache_dir = self.config.path_manager.get_cache_engine_dir()
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.cache_file = self.cache_dir / "search_results.json"
if clear_cache and self.cache_file.exists():
self.cache_file.unlink()
logger.info(f"{self.name}Cleared engine search cache")
self._load_cache()
[docs]
def _load_cache(self):
"""Load cache from file."""
if self.use_cache and self.cache_file.exists():
try:
with open(self.cache_file, "r") as f:
self._cache = json.load(f)
# self._cache = stx.io.load(self.cache_file)
except Exception as exc:
logger.debug(
f"ScholarEngine cache unreadable at {self.cache_file}, "
f"resetting ({type(exc).__name__}: {exc})"
)
self._cache = {}
else:
self._cache = {}
[docs]
def _save_cache(self):
"""Save cache to file."""
if self.use_cache:
try:
with open(self.cache_file, "w") as f:
json.dump(self._cache, f, indent=2)
# stx.io.save(self._cache, self.cache_file)
except Exception as e:
logger.warning(f"Failed to save engine cache: {e}")
[docs]
def _get_cache_key(self, title: str = None, doi: str = None, **kwargs):
"""Generate cache key for search parameters."""
params = {"title": title, "doi": doi, **kwargs}
# Remove None values
params = {k: v for k, v in params.items() if v is not None}
# Create hash from sorted params
import json
param_str = json.dumps(params, sort_keys=True)
return hashlib.md5(param_str.encode()).hexdigest()
[docs]
async def search_async(
self, title: str = None, doi: str = None, **kwargs
) -> Optional[Dict[str, Dict]]:
"""Search all engines and return combined results."""
def _build_readable_query_str(title, doi):
# Log failure if no results found
query_str = (
f"title: {title}"
if title
else f"doi: {doi}"
if doi
else "unknown query"
)
N_PRINT = 50
if len(query_str) < N_PRINT:
return query_str
else:
return f"{query_str[:N_PRINT]}..."
# Check cache first
cache_key = self._get_cache_key(title, doi, **kwargs)
if self.use_cache and cache_key in self._cache:
logger.debug("Using cached search result")
return self._cache[cache_key]
self._last_query_title = title
self._attempted_engines = set()
if self.rotation_manager:
paper_info = {"title": title, **kwargs}
engine_order = self.rotation_manager.get_optimal_engine_order(
paper_info, self.engines, max_engines=len(self.engines)
)
else:
engine_order = self.engines
tasks = []
for engine_name in engine_order:
engine = self._get_engine(engine_name)
if engine:
self._attempted_engines.add(engine_name)
task = self._search_engine_with_timeout(
engine, engine_name, title, doi, **kwargs
)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
engine_results = {}
for ii, (engine_name, result) in enumerate(zip(engine_order, results)):
if isinstance(result, Exception):
logger.debug(f"Error from {engine_name}: {result}")
continue
if result:
logger.debug(
f"{engine_name} returned title: {result.get('basic', {}).get('title', 'N/A')}"
)
engine_results[engine_name] = result
combined_result = self._combine_metadata(engine_results)
query_str = _build_readable_query_str(title, doi)
if not combined_result:
logger.fail(f"No metadata found for {query_str}")
# Cache result
if self.use_cache and combined_result:
self._cache[cache_key] = combined_result
self._save_cache()
logger.success(f"Metadata retrieved for this query: {query_str}")
return combined_result
[docs]
async def search_batch_async(
self,
titles: List[str] = None,
dois: List[str] = None,
) -> List[Optional[Dict[str, Dict]]]:
"""Search multiple papers in batch with parallel processing."""
def _print_stats(queries, results):
failed_queries = []
found_count = 0
for query, result in zip(queries, results):
if isinstance(result, Exception):
failed_queries.append((query, str(result)))
elif result and result.get("id", {}).get("doi"):
found_count += 1
else:
failed_queries.append((query, "No metadata found"))
n_total = len(queries)
success_rate = (
round(100.0 * found_count / n_total, 1) if n_total > 0 else 0.0
)
msg = f"Search engines found {found_count}/{n_total} DOIs from publications (= {success_rate}%)"
if found_count == n_total:
logger.success(msg)
else:
logger.warning(msg)
for query, error in failed_queries:
logger.fail(f"Failed query '{query}': {error}")
if dois:
batched_metadata = [await self.search_async(doi=doi) for doi in tqdm(dois)]
_print_stats(dois, batched_metadata)
return batched_metadata
if titles:
batched_metadata = [
await self.search_async(title=title) for title in tqdm(titles)
]
_print_stats(titles, batched_metadata)
return batched_metadata
return []
def _get_engine(self, name: str):
if name not in self._engine_instances:
engine_classes = {
"URL": URLDOIEngine,
"CrossRef": CrossRefEngine,
"CrossRefLocal": CrossRefLocalEngine,
"OpenAlex": OpenAlexEngine,
"OpenAlexLocal": OpenAlexLocalEngine,
"PubMed": PubMedEngine,
"Semantic_Scholar": SemanticScholarEngine,
"arXiv": ArXivEngine,
}
if name in engine_classes:
if name == "url_doi_engine":
self._engine_instances[name] = engine_classes[name]()
elif name == "CrossRefLocal":
# Get API URL from config (supports SCITEX_SCHOLAR_CROSSREF_API_URL env var)
api_url = self.config.resolve(
"crossref_api_url", "http://127.0.0.1:3333"
)
self._engine_instances[name] = engine_classes[name](
"research@example.com", api_url=api_url
)
elif name == "OpenAlexLocal":
# Get API URL from config (supports SCITEX_SCHOLAR_OPENALEX_API_URL env var)
api_url = self.config.resolve(
"openalex_api_url", "http://127.0.0.1:31292"
)
self._engine_instances[name] = engine_classes[name](
"research@example.com", api_url=api_url
)
else:
self._engine_instances[name] = engine_classes[name](
"research@example.com"
)
return self._engine_instances.get(name)
[docs]
async def _search_engine_with_timeout(
self,
engine,
engine_name: str,
title: str = None,
doi: str = None,
timeout: int = 15,
**kwargs,
):
"""Search single engine with timeout."""
try:
# Record attempt if rotation manager available
if self.rotation_manager:
start_time = time.time()
# Create search task
loop = asyncio.get_event_loop()
search_task = loop.run_in_executor(
None, lambda: engine.search(title=title, doi=doi, **kwargs)
)
# Wait with timeout
result = await asyncio.wait_for(search_task, timeout=timeout)
# Record success
if self.rotation_manager and result:
response_time = time.time() - start_time
self.rotation_manager.record_attempt(
engine_name,
{"title": title, **kwargs},
success=True,
response_time=response_time,
)
return result
except asyncio.TimeoutError:
logger.debug(f"Timeout from {engine_name}")
if self.rotation_manager:
self.rotation_manager.record_attempt(
engine_name, {"title": title, **kwargs}, success=False
)
return None
except Exception as exc:
logger.debug(f"Error from {engine_name}: {exc}")
if self.rotation_manager:
self.rotation_manager.record_attempt(
engine_name, {"title": title, **kwargs}, success=False
)
return None
[docs]
def _identifiers_match(self, ids1: Dict, ids2: Dict) -> bool:
"""Check if any identifiers match between two papers."""
if not ids1 or not ids2:
return False
# Check each identifier type
for id_type in ["doi", "pmid", "arxiv_id", "corpus_id", "scholar_id"]:
val1 = ids1.get(id_type)
val2 = ids2.get(id_type)
if val1 and val2 and val1 == val2:
return True
return False
[docs]
def _validate_paper_consistency(self, metadata_list: List[Dict]) -> bool:
"""Check if all metadata refers to same paper by title, exact year, and first author."""
if not metadata_list or len(metadata_list) < 2:
return True
first = metadata_list[0]
first_title = (first.get("basic", {}).get("title") or "").lower().strip()
first_year = first.get("basic", {}).get("year")
first_authors = first.get("basic", {}).get("authors", [])
first_author_surname = (
first_authors[0].split()[-1].lower() if first_authors else ""
)
for metadata in metadata_list[1:]:
title = (metadata.get("basic", {}).get("title") or "").lower().strip()
year = metadata.get("basic", {}).get("year")
authors = metadata.get("basic", {}).get("authors", [])
first_author = authors[0]
if first_author:
author_surname = authors[0].split()[-1].lower() if authors else ""
else:
author_surname = ""
# Year must be exactly the same
if first_year != year:
return False
# First author surname must match
if first_author_surname and author_surname:
if first_author_surname != author_surname:
return False
# Title similarity check
if first_title and title:
first_words = set(first_title.split())
title_words = set(title.split())
overlap = len(first_words & title_words)
min_len = min(len(first_words), len(title_words))
if overlap < min_len * 0.7:
return False
return True
[docs]
def _validate_against_query(self, metadata: Dict, query_title: str) -> bool:
"""Validate metadata matches the original query with strict title matching."""
if not query_title or not metadata:
return True
paper_title = (metadata.get("basic", {}).get("title") or "").lower().strip()
if not paper_title:
return False
query_title = query_title.lower().strip()
def normalize_title(text):
text = re.sub(r"[^\w\s]", " ", text)
text = re.sub(r"\s+", " ", text).strip()
return text
norm_query = normalize_title(query_title)
norm_paper = normalize_title(paper_title)
# Check if normalized query is substring of paper title or vice versa
if norm_query in norm_paper or norm_paper in norm_query:
return True
# Check word-by-word exact match (order matters)
query_words = norm_query.split()
paper_words = norm_paper.split()
# Find longest common subsequence
common_seq_len = 0
for ii in range(len(paper_words)):
match_len = 0
for jj in range(min(len(query_words), len(paper_words) - ii)):
if (
ii + jj < len(paper_words)
and paper_words[ii + jj] == query_words[jj]
):
match_len += 1
else:
break
common_seq_len = max(common_seq_len, match_len)
# Require at least 80% of query words in sequence
return common_seq_len >= len(query_words) * 0.8
if __name__ == "__main__":
import asyncio
from pprint import pprint
from scitex_scholar import ScholarEngine
async def main_async():
# Query
TITLE = "Attention is All You Need"
TITLE = "Epileptic seizure forecasting with long short-term memory (LSTM) neural networks"
# DOI = "10.1038/nature14539"
# Example: Unified Engine
engine = ScholarEngine(use_cache=False)
outputs = {}
# Search by Title
outputs["metadata_by_title"] = await engine.search_async(
title=TITLE,
)
# # Search by DOI
# outputs["metadata_by_doi"] = await engine.search_async(
# doi=DOI,
# )
for k, v in outputs.items():
print("----------------------------------------")
print(k)
print("----------------------------------------")
pprint(v)
time.sleep(1)
asyncio.run(main_async())
# python -m scitex_scholar.engines.ScholarEngine
# EOF