#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-10-10 00:48:00 (ywatanabe)"
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/scholar/auth/ScholarAuthManager.py
# ----------------------------------------
from __future__ import annotations
import os
__FILE__ = "./src/scitex/scholar/auth/ScholarAuthManager.py"
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------
__FILE__ = __file__
"""
Authentication manager for coordinating multiple authentication providers.
This module manages different authentication methods and provides a unified
interface for authentication operations.
"""
from typing import Any, Dict, List, Optional
import scitex_logging as logging
from scitex_logging import AuthenticationError
from scitex_scholar.config import ScholarConfig
from .providers.BaseAuthenticator import BaseAuthenticator
from .providers.EZProxyAuthenticator import EZProxyAuthenticator
from .providers.OpenAthensAuthenticator import OpenAthensAuthenticator
from .providers.ShibbolethAuthenticator import ShibbolethAuthenticator
logger = logging.getLogger(__name__)
[docs]
class ScholarAuthManager:
"""
Manages multiple authentication providers.
This class coordinates between different authentication methods
(OpenAthens, Lean Library, etc.) and provides a unified interface.
"""
[docs]
def __init__(
self,
email_openathens: Optional[str] = os.getenv("SCITEX_SCHOLAR_OPENATHENS_EMAIL"),
email_ezproxy: Optional[str] = os.getenv("SCITEX_SCHOLAR_EZPROXY_EMAIL"),
email_shibboleth: Optional[str] = os.getenv("SCITEX_SCHOLAR_SHIBBOLETH_EMAIL"),
config: Optional[ScholarConfig] = None,
):
"""Initialize the authentication manager.
Args:
email_openathens: User's institutional email for OpenAthens authentication
email_ezproxy: User's institutional email for EZProxy authentication
email_shibboleth: User's institutional email for Shibboleth authentication
config: ScholarConfig instance (creates new if None)
"""
# Initialize config
self.name = self.__class__.__name__
self.config = config or ScholarConfig()
self.auth_session = None
self.providers: Dict[str, BaseAuthenticator] = {}
self.active_provider: Optional[str] = None
if not any([email_openathens, email_ezproxy, email_shibboleth]):
logger.warning(
f"{self.name}: "
"No authentication provider configured. "
"Set SCITEX_SCHOLAR_OPENATHENS_EMAIL or other provider email."
)
return
for email, provider_str, provider_authenticator in [
(email_openathens, "openathens", OpenAthensAuthenticator),
(email_ezproxy, "ezproxy", EZProxyAuthenticator),
(email_shibboleth, "shibboleth", ShibbolethAuthenticator),
]:
if email:
self._register_provider(
provider_str,
provider_authenticator(
email=email,
config=self.config,
),
)
[docs]
async def ensure_authenticate_async(
self,
provider_name: Optional[str] = None,
verify_live: bool = True,
**kwargs,
) -> bool:
if await self.is_authenticate_async(verify_live=verify_live):
return True
if await self.authenticate_async(provider_name=provider_name, **kwargs):
return True
raise AuthenticationError("Authentication not ensured")
[docs]
async def is_authenticate_async(self, verify_live: bool = True) -> bool:
"""Check if authenticate_async with any provider."""
# Check active provider first
if self.active_provider:
provider = self.providers[self.active_provider]
if await provider.is_authenticate_async(verify_live):
return True
# Check all other providers
for name, provider in self.providers.items():
if name != self.active_provider:
if await provider.is_authenticate_async(verify_live):
self.active_provider = name
return True
logger.debug(f"{self.name}: Not authenticate_async.")
return False
[docs]
async def authenticate_async(
self, provider_name: Optional[str] = None, **kwargs
) -> dict:
"""Authenticate with specified or active provider."""
if provider_name:
if provider_name not in self.providers:
raise AuthenticationError(f"Provider '{provider_name}' not found")
provider = self.providers[provider_name]
elif self.active_provider:
provider = self.providers[self.active_provider]
else:
raise AuthenticationError("No authentication provider configured")
self.auth_session = await provider.authenticate_async(**kwargs)
if self.auth_session and provider_name:
self.active_provider = provider_name
logger.info(f"{self.name}: Authentication succeeded by {provider_name}.")
return self.auth_session
[docs]
async def get_auth_options(self) -> dict:
await self.ensure_authenticate_async()
if self.auth_session and "cookies" in self.auth_session:
return {"storage_state": {"cookies": self.auth_session["cookies"]}}
return {}
# async def get_auth_cookies_async(self) -> List[Dict[str, Any]]:
# """Get authentication cookies from active provider."""
# provider = self.get_active_provider()
# if not provider:
# raise AuthenticationError("No active authentication provider")
# if not await provider.is_authenticate_async():
# raise AuthenticationError("Not authenticate_async")
# authenticated_cookies = await provider.get_auth_cookies_async()
# return authenticated_cookies
[docs]
async def get_auth_cookies_async(
self, essential_only: bool = True
) -> List[Dict[str, Any]]:
"""Get authentication cookies from active provider."""
await self.ensure_authenticate_async()
provider = self.get_active_provider()
# if not provider:
# raise AuthenticationError("No active authentication provider")
# if not await provider.is_authenticate_async():
# raise AuthenticationError("Not authenticate_async")
authenticated_cookies = await provider.get_auth_cookies_async()
if not essential_only:
return authenticated_cookies
# Filter by provider type
if self.active_provider == "openathens":
essential_names = [
"oa-session",
"oa-xsrf-token",
"oatmpsid",
"oalastorg",
]
elif self.active_provider == "ezproxy":
essential_names = ["ezproxy", "session_id"] # adjust as needed
elif self.active_provider == "shibboleth":
essential_names = [
"shib_session",
"_shibsession_",
] # adjust as needed
else:
return authenticated_cookies
filtered_cookies = [
cookie
for cookie in authenticated_cookies
if cookie["name"] in essential_names
]
logger.debug(
f"{self.name}: Filtered to {len(filtered_cookies)} essential cookies for {self.active_provider}"
)
return filtered_cookies
[docs]
def _register_provider(self, name: str, provider: BaseAuthenticator) -> None:
"""Register an authentication provider with email context."""
if not isinstance(provider, BaseAuthenticator):
raise TypeError(
f"Provider must inherit from BaseAuthenticator, got {type(provider)}"
)
self.providers[name] = provider
if not self.active_provider:
self.active_provider = name
logger.debug(f"{self.name}: Registered authentication provider: {name}")
[docs]
def set_active_provider(self, name: str) -> None:
"""Set the active authentication provider."""
if name not in self.providers:
raise ValueError(
f"Provider '{name}' not found. "
f"Available providers: {list(self.providers.keys())}"
)
self.active_provider = name
logger.debug(f"{self.name}: Set active authentication provider: {name}")
[docs]
def get_active_provider(self) -> Optional[BaseAuthenticator]:
"""Get the currently active provider."""
if self.active_provider:
return self.providers.get(self.active_provider)
else:
raise ValueError("Active provider not found. Please set active provider")
[docs]
async def logout_async(self) -> None:
"""Log out from all providers."""
for provider in self.providers.values():
try:
await provider.logout_async()
logger.info(f"{self.name}: Logged out from {provider}")
except Exception as e:
logger.warning(f"{self.name}: Error logging out from {provider}: {e}")
self.active_provider = None
self.auth_session = None
[docs]
def list_providers(self) -> List[str]:
"""List all registered providers."""
return list(self.providers.keys())
if __name__ == "__main__":
import asyncio
async def main():
import os
from scitex_scholar.auth import ScholarAuthManager
auth_manager = ScholarAuthManager(
email_openathens=os.getenv("SCITEX_SCHOLAR_OPENATHENS_EMAIL"),
)
auth_manager.list_providers()
await auth_manager.ensure_authenticate_async()
await auth_manager.get_auth_headers_async()
await auth_manager.get_auth_cookies_async()
asyncio.run(main())
# python -m scitex_scholar.auth.ScholarAuthManager
# EOF