agenticstar_platform
AGENTICSTAR Platform SDK エンタープライズAIエージェント基盤のためのSDK
外部サービス(PostgreSQL、Qdrant、Azure Blob/S3/GCS、Content Safety/PII等)との通信機能を提供
1""" 2AGENTICSTAR Platform SDK 3エンタープライズAIエージェント基盤のためのSDK 4 5外部サービス(PostgreSQL、Qdrant、Azure Blob/S3/GCS、Content Safety/PII等)との通信機能を提供 6""" 7 8__version__ = "0.1.0" 9 10# Common utilities are internal - not exported as public API 11# Use agenticstar_platform.common for internal SDK development only 12 13# Database module 14from .db import PostgreSQLManager, PostgreSQLConfig, AzureADConfig, DataAccess, ConfigAccess 15 16# RAG module 17from .rag import EmbeddingGenerator, EmbeddingConfig, QdrantManager, QdrantConfig 18 19# Events module 20from .events import ( 21 EventType, 22 SubEventType, 23 StreamingEvent, 24 SequencedEvent, 25 ExecutionMessage, 26 EventEmitter, 27 EventHandler, 28 create_sse_handler, 29 create_json_handler, 30 # Event Handlers(マーケットプレイスUI連携用) 31 DatabaseEventHandler, 32 WebhookEventHandler, 33 CompositeEventHandler, 34 create_marketplace_handler, 35) 36 37# Memory module 38from .memory import ( 39 SemanticMemoryClient, 40 SemanticMemoryConfig, 41 EpisodicMemoryClient, 42 EpisodicMemoryConfig, 43 LLMProviderConfig, 44 normalize_group_id, 45) 46 47# Storage module 48from .storage import ( 49 StorageProvider, 50 StorageConfig, 51 AzureBlobConfig, 52 S3Config, 53 GCSConfig, 54 AzureBlobStorageClient, 55 S3StorageClient, 56 GCSStorageClient, 57 UploadResult, 58 DownloadResult, 59 ObjectInfo, 60 ListResult, 61) 62 63# Security module 64from .security import ( 65 SecurityProvider, 66 ContentCategory, 67 PIICategory, 68 SecurityError, 69 SecurityConfigError, 70 SecurityAPIError, 71 ContentModerationResult, 72 PromptShieldResult, 73 PIIEntity, 74 PIIDetectionResult, 75 SecurityCheckResult, 76 AzureSecurityConfig, 77 AWSSecurityConfig, 78 GCPSecurityConfig, 79 AzureSecurityClient, 80 AWSSecurityClient, 81 GCPSecurityClient, 82) 83 84# Auth module 85from .auth import ( 86 LibreChatAuthClient, 87 LibreChatAuthConfig, 88 AuthError, 89 AuthConfigError, 90 AuthAPIError, 91 AuthUnauthorizedError, 92 AuthNotFoundError, 93 AuthRateLimitError, 94 ApiUser, 95 DeviceInfo, 96 LoginHistoryEntry, 97 UserPagination, 98 MCPTokenInfo, 99 MCPTokenError, 100 OAuthProviderName, 101 GetUsersResult, 102 GetUserResult, 103 GetMCPTokensResult, 104) 105 106__all__ = [ 107 # Database 108 "PostgreSQLManager", 109 "PostgreSQLConfig", 110 "AzureADConfig", 111 "DataAccess", 112 "ConfigAccess", 113 # RAG 114 "EmbeddingGenerator", 115 "EmbeddingConfig", 116 "QdrantManager", 117 "QdrantConfig", 118 # Events 119 "EventType", 120 "SubEventType", 121 "StreamingEvent", 122 "SequencedEvent", 123 "ExecutionMessage", 124 "EventEmitter", 125 "EventHandler", 126 "create_sse_handler", 127 "create_json_handler", 128 # Event Handlers(マーケットプレイスUI連携用) 129 "DatabaseEventHandler", 130 "WebhookEventHandler", 131 "CompositeEventHandler", 132 "create_marketplace_handler", 133 # Memory 134 "SemanticMemoryClient", 135 "SemanticMemoryConfig", 136 "EpisodicMemoryClient", 137 "EpisodicMemoryConfig", 138 "LLMProviderConfig", 139 "normalize_group_id", 140 # Storage 141 "StorageProvider", 142 "StorageConfig", 143 "AzureBlobConfig", 144 "S3Config", 145 "GCSConfig", 146 "AzureBlobStorageClient", 147 "S3StorageClient", 148 "GCSStorageClient", 149 "UploadResult", 150 "DownloadResult", 151 "ObjectInfo", 152 "ListResult", 153 # Security 154 "SecurityProvider", 155 "ContentCategory", 156 "PIICategory", 157 "SecurityError", 158 "SecurityConfigError", 159 "SecurityAPIError", 160 "ContentModerationResult", 161 "PromptShieldResult", 162 "PIIEntity", 163 "PIIDetectionResult", 164 "SecurityCheckResult", 165 "AzureSecurityConfig", 166 "AWSSecurityConfig", 167 "GCPSecurityConfig", 168 "AzureSecurityClient", 169 "AWSSecurityClient", 170 "GCPSecurityClient", 171 # Auth 172 "LibreChatAuthClient", 173 "LibreChatAuthConfig", 174 "AuthError", 175 "AuthConfigError", 176 "AuthAPIError", 177 "AuthUnauthorizedError", 178 "AuthNotFoundError", 179 "AuthRateLimitError", 180 "ApiUser", 181 "DeviceInfo", 182 "LoginHistoryEntry", 183 "UserPagination", 184 "MCPTokenInfo", 185 "MCPTokenError", 186 "OAuthProviderName", 187 "GetUsersResult", 188 "GetUserResult", 189 "GetMCPTokensResult", 190]
26class PostgreSQLManager: 27 """PostgreSQL接続とクエリ実行クラス 28 29 Azure AD認証またはユーザー名/パスワード認証でPostgreSQLに接続し、 30 クエリを実行します。接続プールを使用して効率的に接続を管理します。 31 32 Features: 33 - Azure AD認証対応(トークン自動更新) 34 - 標準認証(ユーザー名/パスワード)対応 35 - 接続プール管理 36 - 非同期コンテキストマネージャー対応 37 - SQL Injection防止(パラメータ化クエリ) 38 39 Example: 40 >>> from agenticstar_platform import PostgreSQLManager, PostgreSQLConfig 41 >>> 42 >>> # 設定を作成 43 >>> config = PostgreSQLConfig.from_toml("config.toml") 44 >>> 45 >>> # コンテキストマネージャーで使用(推奨) 46 >>> async with PostgreSQLManager(config) as db: 47 ... # SELECT実行 48 ... result = await db.fetch_all("SELECT * FROM users WHERE status = $1", ("active",)) 49 ... print(f"Active users: {len(result)}") 50 ... 51 ... # INSERT実行 52 ... await db.execute( 53 ... "INSERT INTO logs (message) VALUES ($1)", 54 ... ("Operation completed",) 55 ... ) 56 >>> 57 >>> # 明示的な初期化・クローズ 58 >>> db = PostgreSQLManager(config) 59 >>> await db.initialize() 60 >>> try: 61 ... result = await db.fetch_one("SELECT COUNT(*) as cnt FROM users") 62 ... print(f"Total users: {result['cnt']}") 63 ... finally: 64 ... await db.close() 65 """ 66 67 def __init__(self, config: PostgreSQLConfig): 68 self.config = config 69 self.pool: Optional[asyncpg.Pool] = None 70 self._access_token: Optional[str] = None 71 self._token_expiry: Optional[float] = None 72 self._token_refresh_lock: Optional[asyncio.Lock] = None 73 74 async def __aenter__(self) -> "PostgreSQLManager": 75 """Context Manager: 開始時に接続プールを初期化""" 76 await self.initialize() 77 return self 78 79 async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: 80 """Context Manager: 終了時に接続プールをクローズ""" 81 await self.close() 82 83 async def _get_azure_ad_token(self) -> str: 84 """Azure AD認証でアクセストークンを取得""" 85 logger.info("[IN] _get_azure_ad_token: Getting Azure AD token") 86 87 if not self.config.azure_ad: 88 logger.error("[ERROR] _get_azure_ad_token: Azure AD configuration is required") 89 raise ValueError("Azure AD configuration is required") 90 91 # トークンの有効期限をチェック(5分前に更新) 92 if ( 93 self._access_token 94 and self._token_expiry 95 and time.time() < self._token_expiry - 300 96 ): 97 logger.info("[OUT] _get_azure_ad_token: Using cached token") 98 return self._access_token 99 100 # トークン取得中の同期処理(複数API同時実行時の重複取得防止) 101 # Note: Lockはinitialize()で作成済み(イベントループ紐付けを保証) 102 async with self._token_refresh_lock: 103 # ロック取得後に再チェック(他のAPIコールが既に更新した可能性) 104 if ( 105 self._access_token 106 and self._token_expiry 107 and time.time() < self._token_expiry - 300 108 ): 109 return self._access_token 110 111 try: 112 credential = ClientSecretCredential( 113 tenant_id=self.config.azure_ad.tenant_id, 114 client_id=self.config.azure_ad.client_id, 115 client_secret=self.config.azure_ad.client_secret, 116 ) 117 118 # Try different scopes for Azure PostgreSQL Flexible Server 119 scopes_to_try = [ 120 "https://ossrdbms-aad.database.windows.net/.default", 121 "https://database.windows.net/.default", 122 "https://management.azure.com/.default", 123 ] 124 125 token = None 126 last_error = None 127 128 for scope in scopes_to_try: 129 try: 130 token = await asyncio.to_thread(credential.get_token, scope) 131 break 132 except Exception as e: 133 last_error = e 134 continue 135 136 if not token: 137 raise last_error or Exception("All token scopes failed") 138 139 self._access_token = token.token 140 self._token_expiry = token.expires_on 141 142 logger.info("[OUT] _get_azure_ad_token: Azure AD token acquired successfully") 143 return self._access_token 144 145 except Exception as e: 146 logger.error( 147 "[ERROR] _get_azure_ad_token: Failed to acquire Azure AD token - error_type=%s", 148 type(e).__name__, 149 ) 150 raise 151 152 async def _create_connection_params(self) -> Dict[str, Any]: 153 """接続パラメータを作成""" 154 logger.info("[IN] _create_connection_params: Creating database connection parameters") 155 params = { 156 "host": self.config.host, 157 "port": self.config.port, 158 "database": self.config.database, 159 "command_timeout": self.config.command_timeout, 160 } 161 162 # SSL設定(disable以外は設定する) 163 ssl_mode = getattr(self.config, 'ssl_mode', 'require') 164 if ssl_mode != "disable": 165 params["ssl"] = ssl_mode 166 167 if self.config.use_azure_ad: 168 token = await self._get_azure_ad_token() 169 170 # Use configured username if available, otherwise fallback to client_id 171 if self.config.azure_ad and self.config.azure_ad.username: 172 username = self.config.azure_ad.username 173 elif self.config.azure_ad: 174 username = self.config.azure_ad.client_id 175 else: 176 raise ValueError("Azure AD configuration is required when use_azure_ad is True") 177 178 params.update( 179 { 180 "user": username, 181 "password": token, 182 } 183 ) 184 else: 185 # 標準認証の場合(ユーザー名/パスワード) 186 if not self.config.username or not self.config.password: 187 raise ValueError("Username and password are required when use_azure_ad is False") 188 189 params.update( 190 { 191 "user": self.config.username, 192 "password": self.config.password, 193 } 194 ) 195 196 logger.info("[OUT] _create_connection_params: Connection parameters created successfully") 197 return params 198 199 async def _ensure_initialized(self) -> None: 200 """接続プールが初期化されていることを保証""" 201 if not self.pool: 202 await self.initialize() 203 204 async def initialize(self) -> None: 205 """データベース接続プールを初期化""" 206 logger.info("[IN] initialize: Initializing PostgreSQL connection pool") 207 try: 208 # トークンリフレッシュ用のLockを作成(イベントループ内で作成することで紐付けを保証) 209 if self._token_refresh_lock is None: 210 self._token_refresh_lock = asyncio.Lock() 211 212 logger.info("[IN] initialize.create_connection_params: Creating connection parameters") 213 connection_params = await self._create_connection_params() 214 logger.info("[OUT] initialize.create_connection_params: Connection parameters created") 215 216 logger.info( 217 "[IN] initialize.create_pool: Creating connection pool - min_size=%s, max_size=%s", 218 self.config.pool_min_size, 219 self.config.pool_max_size, 220 ) 221 self.pool = await asyncpg.create_pool( 222 min_size=self.config.pool_min_size, 223 max_size=self.config.pool_max_size, 224 **connection_params, 225 ) 226 logger.info("[OUT] initialize.create_pool: Connection pool created successfully") 227 logger.info("[OUT] initialize: PostgreSQL connection pool initialized successfully") 228 229 except Exception as e: 230 logger.error( 231 "[ERROR] initialize: PostgreSQL initialization failed - error_type=%s, error=%s", 232 type(e).__name__, 233 str(e), 234 ) 235 raise 236 237 async def close(self) -> None: 238 """接続プールを閉じる 239 240 Note: 241 例外が発生してもリソースを確実にクリーンアップします。 242 """ 243 logger.info("[IN] close: Closing PostgreSQL connection pool") 244 try: 245 if self.pool: 246 await self.pool.close() 247 logger.info("[OUT] close: PostgreSQL connection pool closed successfully") 248 else: 249 logger.info("[OUT] close: No connection pool to close") 250 except Exception as e: 251 logger.warning(f"Error closing PostgreSQL connection pool: {e}") 252 finally: 253 self.pool = None 254 self._access_token = None 255 self._token_expiry = None 256 257 async def execute_query( 258 self, query: str, params: tuple = () 259 ) -> Dict[str, Any]: 260 """ 261 汎用クエリ実行メソッド 262 263 SELECT/INSERT/UPDATE/DELETE等のSQLクエリを実行し、 264 結果を統一フォーマットで返します。 265 266 Args: 267 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 268 params: クエリパラメータ(タプル) 269 270 Returns: 271 Dict[str, Any]: 実行結果 272 - success: bool - 実行成功/失敗 273 - data: List[Dict] (SELECT時) or Dict (INSERT/UPDATE/DELETE時) 274 - error: Optional[str] - エラーメッセージ(失敗時のみ) 275 - error_code: Optional[str] - エラーコード(失敗時のみ) 276 277 Example: 278 >>> # SELECTクエリ 279 >>> result = await db.execute_query( 280 ... "SELECT * FROM users WHERE created_at > $1", 281 ... (datetime(2024, 1, 1),) 282 ... ) 283 >>> if result["success"]: 284 ... for user in result["data"]: 285 ... print(user["name"]) 286 >>> 287 >>> # INSERT with RETURNING 288 >>> result = await db.execute_query( 289 ... "INSERT INTO users (name) VALUES ($1) RETURNING id", 290 ... ("John",) 291 ... ) 292 >>> if result["success"]: 293 ... new_id = result["data"][0]["id"] 294 """ 295 logger.info( 296 "[IN] execute_query: Executing SQL query - query_type=%s, params_count=%s", 297 (query.strip().upper().split()[0] if query.strip() else "EMPTY"), 298 len(params), 299 ) 300 try: 301 await self._ensure_initialized() 302 303 async with self.pool.acquire() as conn: 304 # クエリの種類を判定 305 normalized_query = query.strip().upper() 306 is_select = normalized_query.startswith("SELECT") 307 has_returning = "RETURNING" in normalized_query 308 309 if is_select or has_returning: 310 # SELECT文またはRETURNING句がある場合はfetchを使用 311 rows = await conn.fetch(query, *params) 312 # Rowオブジェクトを辞書に変換 313 result_data = [dict(row) for row in rows] 314 logger.info( 315 "[OUT] execute_query: Query executed successfully - result_count=%s", 316 len(result_data), 317 ) 318 return {"success": True, "data": result_data} 319 else: 320 # INSERT/UPDATE/DELETE等(RETURNING句なし)の場合はexecuteを使用 321 result = await conn.execute(query, *params) 322 logger.info("[OUT] execute_query: Query executed successfully") 323 return {"success": True, "data": {"result": result}} 324 325 except Exception as e: 326 logger.error( 327 "[ERROR] execute_query: Query execution failed - error_type=%s, error=%s", 328 type(e).__name__, 329 str(e), 330 ) 331 return { 332 "success": False, 333 "data": None, 334 "error": f"Database query error: {str(e)}", 335 "error_code": "DB_QUERY_ERROR", 336 } 337 338 async def fetch_one( 339 self, query: str, params: tuple = () 340 ) -> Optional[Dict[str, Any]]: 341 """ 342 単一行を取得 343 344 1行だけ返すことが期待されるクエリに使用します。 345 結果が複数行ある場合でも最初の1行のみ返します。 346 347 Args: 348 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 349 params: クエリパラメータ(タプル) 350 351 Returns: 352 Optional[Dict[str, Any]]: 結果行(辞書形式)、結果がない場合はNone 353 354 Example: 355 >>> # ユーザーを1件取得 356 >>> user = await db.fetch_one( 357 ... "SELECT * FROM users WHERE id = $1", 358 ... (user_id,) 359 ... ) 360 >>> if user: 361 ... print(f"Found: {user['name']}") 362 ... else: 363 ... print("User not found") 364 >>> 365 >>> # 集計結果を取得 366 >>> count = await db.fetch_one("SELECT COUNT(*) as total FROM users") 367 >>> print(f"Total: {count['total']}") 368 """ 369 await self._ensure_initialized() 370 async with self.pool.acquire() as conn: 371 row = await conn.fetchrow(query, *params) 372 return dict(row) if row else None 373 374 async def fetch_all( 375 self, query: str, params: tuple = () 376 ) -> List[Dict[str, Any]]: 377 """ 378 全行を取得 379 380 複数行を返すSELECTクエリに使用します。 381 結果がない場合は空のリストを返します。 382 383 Args: 384 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 385 params: クエリパラメータ(タプル) 386 387 Returns: 388 List[Dict[str, Any]]: 結果行のリスト(各行は辞書形式) 389 390 Example: 391 >>> # 全ユーザーを取得 392 >>> users = await db.fetch_all("SELECT * FROM users ORDER BY created_at DESC") 393 >>> for user in users: 394 ... print(f"{user['id']}: {user['name']}") 395 >>> 396 >>> # 条件付きで取得 397 >>> active_users = await db.fetch_all( 398 ... "SELECT * FROM users WHERE status = $1 AND role = $2", 399 ... ("active", "admin") 400 ... ) 401 >>> print(f"Found {len(active_users)} admin users") 402 """ 403 await self._ensure_initialized() 404 async with self.pool.acquire() as conn: 405 rows = await conn.fetch(query, *params) 406 return [dict(row) for row in rows] 407 408 async def execute(self, query: str, params: tuple = ()) -> str: 409 """ 410 実行のみ(結果を返さないクエリ用) 411 412 INSERT/UPDATE/DELETE等の結果データが不要なクエリに使用します。 413 DDL文(CREATE TABLE等)の実行にも使用できます。 414 415 Args: 416 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 417 params: クエリパラメータ(タプル) 418 419 Returns: 420 str: 実行結果のステータス(例: "INSERT 0 1", "UPDATE 5", "DELETE 3") 421 422 Example: 423 >>> # データ挿入 424 >>> result = await db.execute( 425 ... "INSERT INTO logs (level, message) VALUES ($1, $2)", 426 ... ("INFO", "Operation completed") 427 ... ) 428 >>> print(result) # "INSERT 0 1" 429 >>> 430 >>> # データ更新 431 >>> result = await db.execute( 432 ... "UPDATE users SET status = $1 WHERE last_login < $2", 433 ... ("inactive", datetime(2024, 1, 1)) 434 ... ) 435 >>> print(result) # "UPDATE 15" 436 >>> 437 >>> # データ削除 438 >>> result = await db.execute( 439 ... "DELETE FROM sessions WHERE expired_at < $1", 440 ... (datetime.now(),) 441 ... ) 442 >>> print(result) # "DELETE 100" 443 """ 444 await self._ensure_initialized() 445 async with self.pool.acquire() as conn: 446 return await conn.execute(query, *params)
PostgreSQL接続とクエリ実行クラス
Azure AD認証またはユーザー名/パスワード認証でPostgreSQLに接続し、 クエリを実行します。接続プールを使用して効率的に接続を管理します。
Features: - Azure AD認証対応(トークン自動更新) - 標準認証(ユーザー名/パスワード)対応 - 接続プール管理 - 非同期コンテキストマネージャー対応 - SQL Injection防止(パラメータ化クエリ)
Example:
from agenticstar_platform import PostgreSQLManager, PostgreSQLConfig
設定を作成
config = PostgreSQLConfig.from_toml("config.toml")
コンテキストマネージャーで使用(推奨)
async with PostgreSQLManager(config) as db: ... # SELECT実行 ... result = await db.fetch_all("SELECT * FROM users WHERE status = $1", ("active",)) ... print(f"Active users: {len(result)}") ... ... # INSERT実行 ... await db.execute( ... "INSERT INTO logs (message) VALUES ($1)", ... ("Operation completed",) ... )
明示的な初期化・クローズ
db = PostgreSQLManager(config) await db.initialize() try: ... result = await db.fetch_one("SELECT COUNT(*) as cnt FROM users") ... print(f"Total users: {result['cnt']}") ... finally: ... await db.close()
204 async def initialize(self) -> None: 205 """データベース接続プールを初期化""" 206 logger.info("[IN] initialize: Initializing PostgreSQL connection pool") 207 try: 208 # トークンリフレッシュ用のLockを作成(イベントループ内で作成することで紐付けを保証) 209 if self._token_refresh_lock is None: 210 self._token_refresh_lock = asyncio.Lock() 211 212 logger.info("[IN] initialize.create_connection_params: Creating connection parameters") 213 connection_params = await self._create_connection_params() 214 logger.info("[OUT] initialize.create_connection_params: Connection parameters created") 215 216 logger.info( 217 "[IN] initialize.create_pool: Creating connection pool - min_size=%s, max_size=%s", 218 self.config.pool_min_size, 219 self.config.pool_max_size, 220 ) 221 self.pool = await asyncpg.create_pool( 222 min_size=self.config.pool_min_size, 223 max_size=self.config.pool_max_size, 224 **connection_params, 225 ) 226 logger.info("[OUT] initialize.create_pool: Connection pool created successfully") 227 logger.info("[OUT] initialize: PostgreSQL connection pool initialized successfully") 228 229 except Exception as e: 230 logger.error( 231 "[ERROR] initialize: PostgreSQL initialization failed - error_type=%s, error=%s", 232 type(e).__name__, 233 str(e), 234 ) 235 raise
データベース接続プールを初期化
237 async def close(self) -> None: 238 """接続プールを閉じる 239 240 Note: 241 例外が発生してもリソースを確実にクリーンアップします。 242 """ 243 logger.info("[IN] close: Closing PostgreSQL connection pool") 244 try: 245 if self.pool: 246 await self.pool.close() 247 logger.info("[OUT] close: PostgreSQL connection pool closed successfully") 248 else: 249 logger.info("[OUT] close: No connection pool to close") 250 except Exception as e: 251 logger.warning(f"Error closing PostgreSQL connection pool: {e}") 252 finally: 253 self.pool = None 254 self._access_token = None 255 self._token_expiry = None
接続プールを閉じる
Note: 例外が発生してもリソースを確実にクリーンアップします。
257 async def execute_query( 258 self, query: str, params: tuple = () 259 ) -> Dict[str, Any]: 260 """ 261 汎用クエリ実行メソッド 262 263 SELECT/INSERT/UPDATE/DELETE等のSQLクエリを実行し、 264 結果を統一フォーマットで返します。 265 266 Args: 267 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 268 params: クエリパラメータ(タプル) 269 270 Returns: 271 Dict[str, Any]: 実行結果 272 - success: bool - 実行成功/失敗 273 - data: List[Dict] (SELECT時) or Dict (INSERT/UPDATE/DELETE時) 274 - error: Optional[str] - エラーメッセージ(失敗時のみ) 275 - error_code: Optional[str] - エラーコード(失敗時のみ) 276 277 Example: 278 >>> # SELECTクエリ 279 >>> result = await db.execute_query( 280 ... "SELECT * FROM users WHERE created_at > $1", 281 ... (datetime(2024, 1, 1),) 282 ... ) 283 >>> if result["success"]: 284 ... for user in result["data"]: 285 ... print(user["name"]) 286 >>> 287 >>> # INSERT with RETURNING 288 >>> result = await db.execute_query( 289 ... "INSERT INTO users (name) VALUES ($1) RETURNING id", 290 ... ("John",) 291 ... ) 292 >>> if result["success"]: 293 ... new_id = result["data"][0]["id"] 294 """ 295 logger.info( 296 "[IN] execute_query: Executing SQL query - query_type=%s, params_count=%s", 297 (query.strip().upper().split()[0] if query.strip() else "EMPTY"), 298 len(params), 299 ) 300 try: 301 await self._ensure_initialized() 302 303 async with self.pool.acquire() as conn: 304 # クエリの種類を判定 305 normalized_query = query.strip().upper() 306 is_select = normalized_query.startswith("SELECT") 307 has_returning = "RETURNING" in normalized_query 308 309 if is_select or has_returning: 310 # SELECT文またはRETURNING句がある場合はfetchを使用 311 rows = await conn.fetch(query, *params) 312 # Rowオブジェクトを辞書に変換 313 result_data = [dict(row) for row in rows] 314 logger.info( 315 "[OUT] execute_query: Query executed successfully - result_count=%s", 316 len(result_data), 317 ) 318 return {"success": True, "data": result_data} 319 else: 320 # INSERT/UPDATE/DELETE等(RETURNING句なし)の場合はexecuteを使用 321 result = await conn.execute(query, *params) 322 logger.info("[OUT] execute_query: Query executed successfully") 323 return {"success": True, "data": {"result": result}} 324 325 except Exception as e: 326 logger.error( 327 "[ERROR] execute_query: Query execution failed - error_type=%s, error=%s", 328 type(e).__name__, 329 str(e), 330 ) 331 return { 332 "success": False, 333 "data": None, 334 "error": f"Database query error: {str(e)}", 335 "error_code": "DB_QUERY_ERROR", 336 }
汎用クエリ実行メソッド
SELECT/INSERT/UPDATE/DELETE等のSQLクエリを実行し、 結果を統一フォーマットで返します。
Args: query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 params: クエリパラメータ(タプル)
Returns: Dict[str, Any]: 実行結果 - success: bool - 実行成功/失敗 - data: List[Dict] (SELECT時) or Dict (INSERT/UPDATE/DELETE時) - error: Optional[str] - エラーメッセージ(失敗時のみ) - error_code: Optional[str] - エラーコード(失敗時のみ)
Example:
SELECTクエリ
result = await db.execute_query( ... "SELECT * FROM users WHERE created_at > $1", ... (datetime(2024, 1, 1),) ... ) if result["success"]: ... for user in result["data"]: ... print(user["name"])
INSERT with RETURNING
result = await db.execute_query( ... "INSERT INTO users (name) VALUES ($1) RETURNING id", ... ("John",) ... ) if result["success"]: ... new_id = result["data"][0]["id"]
338 async def fetch_one( 339 self, query: str, params: tuple = () 340 ) -> Optional[Dict[str, Any]]: 341 """ 342 単一行を取得 343 344 1行だけ返すことが期待されるクエリに使用します。 345 結果が複数行ある場合でも最初の1行のみ返します。 346 347 Args: 348 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 349 params: クエリパラメータ(タプル) 350 351 Returns: 352 Optional[Dict[str, Any]]: 結果行(辞書形式)、結果がない場合はNone 353 354 Example: 355 >>> # ユーザーを1件取得 356 >>> user = await db.fetch_one( 357 ... "SELECT * FROM users WHERE id = $1", 358 ... (user_id,) 359 ... ) 360 >>> if user: 361 ... print(f"Found: {user['name']}") 362 ... else: 363 ... print("User not found") 364 >>> 365 >>> # 集計結果を取得 366 >>> count = await db.fetch_one("SELECT COUNT(*) as total FROM users") 367 >>> print(f"Total: {count['total']}") 368 """ 369 await self._ensure_initialized() 370 async with self.pool.acquire() as conn: 371 row = await conn.fetchrow(query, *params) 372 return dict(row) if row else None
単一行を取得
1行だけ返すことが期待されるクエリに使用します。 結果が複数行ある場合でも最初の1行のみ返します。
Args: query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 params: クエリパラメータ(タプル)
Returns: Optional[Dict[str, Any]]: 結果行(辞書形式)、結果がない場合はNone
Example:
ユーザーを1件取得
user = await db.fetch_one( ... "SELECT * FROM users WHERE id = $1", ... (user_id,) ... ) if user: ... print(f"Found: {user['name']}") ... else: ... print("User not found")
集計結果を取得
count = await db.fetch_one("SELECT COUNT(*) as total FROM users") print(f"Total: {count['total']}")
374 async def fetch_all( 375 self, query: str, params: tuple = () 376 ) -> List[Dict[str, Any]]: 377 """ 378 全行を取得 379 380 複数行を返すSELECTクエリに使用します。 381 結果がない場合は空のリストを返します。 382 383 Args: 384 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 385 params: クエリパラメータ(タプル) 386 387 Returns: 388 List[Dict[str, Any]]: 結果行のリスト(各行は辞書形式) 389 390 Example: 391 >>> # 全ユーザーを取得 392 >>> users = await db.fetch_all("SELECT * FROM users ORDER BY created_at DESC") 393 >>> for user in users: 394 ... print(f"{user['id']}: {user['name']}") 395 >>> 396 >>> # 条件付きで取得 397 >>> active_users = await db.fetch_all( 398 ... "SELECT * FROM users WHERE status = $1 AND role = $2", 399 ... ("active", "admin") 400 ... ) 401 >>> print(f"Found {len(active_users)} admin users") 402 """ 403 await self._ensure_initialized() 404 async with self.pool.acquire() as conn: 405 rows = await conn.fetch(query, *params) 406 return [dict(row) for row in rows]
全行を取得
複数行を返すSELECTクエリに使用します。 結果がない場合は空のリストを返します。
Args: query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 params: クエリパラメータ(タプル)
Returns: List[Dict[str, Any]]: 結果行のリスト(各行は辞書形式)
Example:
全ユーザーを取得
users = await db.fetch_all("SELECT * FROM users ORDER BY created_at DESC") for user in users: ... print(f"{user['id']}: {user['name']}")
条件付きで取得
active_users = await db.fetch_all( ... "SELECT * FROM users WHERE status = $1 AND role = $2", ... ("active", "admin") ... ) print(f"Found {len(active_users)} admin users")
408 async def execute(self, query: str, params: tuple = ()) -> str: 409 """ 410 実行のみ(結果を返さないクエリ用) 411 412 INSERT/UPDATE/DELETE等の結果データが不要なクエリに使用します。 413 DDL文(CREATE TABLE等)の実行にも使用できます。 414 415 Args: 416 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 417 params: クエリパラメータ(タプル) 418 419 Returns: 420 str: 実行結果のステータス(例: "INSERT 0 1", "UPDATE 5", "DELETE 3") 421 422 Example: 423 >>> # データ挿入 424 >>> result = await db.execute( 425 ... "INSERT INTO logs (level, message) VALUES ($1, $2)", 426 ... ("INFO", "Operation completed") 427 ... ) 428 >>> print(result) # "INSERT 0 1" 429 >>> 430 >>> # データ更新 431 >>> result = await db.execute( 432 ... "UPDATE users SET status = $1 WHERE last_login < $2", 433 ... ("inactive", datetime(2024, 1, 1)) 434 ... ) 435 >>> print(result) # "UPDATE 15" 436 >>> 437 >>> # データ削除 438 >>> result = await db.execute( 439 ... "DELETE FROM sessions WHERE expired_at < $1", 440 ... (datetime.now(),) 441 ... ) 442 >>> print(result) # "DELETE 100" 443 """ 444 await self._ensure_initialized() 445 async with self.pool.acquire() as conn: 446 return await conn.execute(query, *params)
実行のみ(結果を返さないクエリ用)
INSERT/UPDATE/DELETE等の結果データが不要なクエリに使用します。 DDL文(CREATE TABLE等)の実行にも使用できます。
Args: query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 params: クエリパラメータ(タプル)
Returns: str: 実行結果のステータス(例: "INSERT 0 1", "UPDATE 5", "DELETE 3")
Example:
データ挿入
result = await db.execute( ... "INSERT INTO logs (level, message) VALUES ($1, $2)", ... ("INFO", "Operation completed") ... ) print(result) # "INSERT 0 1"
データ更新
result = await db.execute( ... "UPDATE users SET status = $1 WHERE last_login < $2", ... ("inactive", datetime(2024, 1, 1)) ... ) print(result) # "UPDATE 15"
データ削除
result = await db.execute( ... "DELETE FROM sessions WHERE expired_at < $1", ... (datetime.now(),) ... ) print(result) # "DELETE 100"
101@dataclass 102class PostgreSQLConfig: 103 """PostgreSQL設定 104 105 Note: 106 passwordはrepr=Falseでログ出力から除外されます。 107 108 Example: 109 >>> # 辞書から作成 110 >>> config = PostgreSQLConfig.from_dict({ 111 ... "host": "localhost", 112 ... "port": 5432, 113 ... "database": "mydb", 114 ... "username": "user", 115 ... "password": "pass" 116 ... }) 117 118 >>> # TOMLファイルから作成 119 >>> config = PostgreSQLConfig.from_toml("config.toml", section="database") 120 """ 121 122 provider: str = "postgresql" 123 host: str = "localhost" 124 port: int = 5432 125 database: str = "autonomousagent_db" 126 use_azure_ad: bool = False 127 128 # 標準認証用(use_azure_ad=Falseの場合) 129 username: Optional[str] = None 130 password: Optional[str] = field(default=None, repr=False) # シークレット: reprから除外 131 132 # 接続プール設定 133 pool_min_size: int = 5 134 pool_max_size: int = 20 135 command_timeout: int = 60 136 pool_timeout: Optional[int] = 30 137 max_overflow: Optional[int] = 10 138 139 # Azure AD認証設定 140 azure_ad: Optional[AzureADConfig] = None 141 142 # API Proxy設定(CLIモード用) 143 api_proxy_url: Optional[str] = None 144 145 # SSL設定(デフォルト: require, ローカル開発用: disable) 146 ssl_mode: str = "require" 147 148 def __str__(self) -> str: 149 return ( 150 f"PostgreSQLConfig(host={self.host!r}, port={self.port}, " 151 f"database={self.database!r}, use_azure_ad={self.use_azure_ad}, password=***)" 152 ) 153 154 @classmethod 155 def from_dict(cls, data: Dict[str, Any]) -> "PostgreSQLConfig": 156 """辞書からPostgreSQLConfigを作成 157 158 Args: 159 data: 設定辞書 160 161 Returns: 162 PostgreSQLConfig instance 163 164 Example: 165 >>> config = PostgreSQLConfig.from_dict({ 166 ... "host": "db.example.com", 167 ... "port": 5432, 168 ... "database": "mydb", 169 ... "username": "user", 170 ... "password": "secret", 171 ... "use_azure_ad": False, 172 ... "pool_max_size": 30 173 ... }) 174 """ 175 # Azure AD設定がある場合はネストして処理 176 azure_ad = None 177 if "azure_ad" in data and data.get("use_azure_ad"): 178 azure_ad = AzureADConfig.from_dict(data["azure_ad"]) 179 180 return cls( 181 provider=data.get("provider", "postgresql"), 182 host=data.get("host", "localhost"), 183 port=data.get("port", 5432), 184 database=data.get("database", "autonomousagent_db"), 185 use_azure_ad=data.get("use_azure_ad", False), 186 username=data.get("username"), 187 password=data.get("password"), 188 pool_min_size=data.get("pool_min_size", 5), 189 pool_max_size=data.get("pool_max_size", 20), 190 command_timeout=data.get("command_timeout", 60), 191 pool_timeout=data.get("pool_timeout", 30), 192 max_overflow=data.get("max_overflow", 10), 193 azure_ad=azure_ad, 194 api_proxy_url=data.get("api_proxy_url"), 195 ) 196 197 @classmethod 198 def from_toml(cls, toml_path: str, section: str = "database") -> "PostgreSQLConfig": 199 """TOMLファイルからPostgreSQLConfigを作成 200 201 Args: 202 toml_path: TOMLファイルパス 203 section: セクション名(ドット区切りでネスト対応) 204 205 Returns: 206 PostgreSQLConfig instance 207 208 Example: 209 >>> # config.toml の [database] セクションを読み込み 210 >>> config = PostgreSQLConfig.from_toml("config.toml") 211 212 >>> # ネストされたセクション 213 >>> config = PostgreSQLConfig.from_toml("config.toml", section="services.database") 214 """ 215 if not TOML_AVAILABLE: 216 raise ImportError("toml library is required: pip install toml") 217 218 with open(toml_path, "r") as f: 219 toml_data = toml.load(f) 220 221 # ネストされたセクションをドット区切りで辿る 222 data = toml_data 223 for key in section.split("."): 224 data = data[key] 225 226 return cls.from_dict(data) 227 228 @classmethod 229 def from_env(cls, prefix: str = "DB_") -> "PostgreSQLConfig": 230 """環境変数からPostgreSQLConfigを作成 231 232 環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。 233 234 Args: 235 prefix: 環境変数名のプレフィックス(デフォルト: "DB_") 236 237 Returns: 238 PostgreSQLConfig instance 239 240 Example: 241 >>> # 環境変数から作成(DB_HOST, DB_PORT, DB_DATABASE, etc.) 242 >>> config = PostgreSQLConfig.from_env() 243 244 >>> # カスタムプレフィックス(POSTGRES_HOST, POSTGRES_PORT, etc.) 245 >>> config = PostgreSQLConfig.from_env(prefix="POSTGRES_") 246 247 Environment Variables: 248 {prefix}HOST: データベースホスト(デフォルト: localhost) 249 {prefix}PORT: ポート番号(デフォルト: 5432) 250 {prefix}DATABASE: データベース名(デフォルト: autonomousagent_db) 251 {prefix}USER / {prefix}USERNAME: ユーザー名 252 {prefix}PASSWORD: パスワード 253 {prefix}USE_AZURE_AD: Azure AD認証を使用するか(true/false) 254 {prefix}POOL_MIN_SIZE: 最小プールサイズ(デフォルト: 5) 255 {prefix}POOL_MAX_SIZE: 最大プールサイズ(デフォルト: 20) 256 {prefix}API_PROXY_URL: API Proxy URL(CLIモード用) 257 """ 258 def get_env(key: str, default: str = "") -> str: 259 return os.environ.get(f"{prefix}{key}", default) 260 261 def get_env_int(key: str, default: int) -> int: 262 val = os.environ.get(f"{prefix}{key}") 263 return int(val) if val else default 264 265 def get_env_bool(key: str, default: bool = False) -> bool: 266 val = os.environ.get(f"{prefix}{key}", "").lower() 267 return val in ("true", "1", "yes") if val else default 268 269 # ユーザー名は USER または USERNAME どちらでも可 270 username = get_env("USER") or get_env("USERNAME") 271 272 return cls( 273 host=get_env("HOST", "localhost"), 274 port=get_env_int("PORT", 5432), 275 database=get_env("DATABASE", "autonomousagent_db"), 276 username=username or None, 277 password=get_env("PASSWORD") or None, 278 use_azure_ad=get_env_bool("USE_AZURE_AD"), 279 pool_min_size=get_env_int("POOL_MIN_SIZE", 5), 280 pool_max_size=get_env_int("POOL_MAX_SIZE", 20), 281 command_timeout=get_env_int("COMMAND_TIMEOUT", 60), 282 pool_timeout=get_env_int("POOL_TIMEOUT", 30), 283 max_overflow=get_env_int("MAX_OVERFLOW", 10), 284 api_proxy_url=get_env("API_PROXY_URL") or None, 285 )
PostgreSQL設定
Note: passwordはrepr=Falseでログ出力から除外されます。
Example:
辞書から作成
config = PostgreSQLConfig.from_dict({ ... "host": "localhost", ... "port": 5432, ... "database": "mydb", ... "username": "user", ... "password": "pass" ... })
>>> # TOMLファイルから作成 >>> config = PostgreSQLConfig.from_toml("config.toml", section="database")
154 @classmethod 155 def from_dict(cls, data: Dict[str, Any]) -> "PostgreSQLConfig": 156 """辞書からPostgreSQLConfigを作成 157 158 Args: 159 data: 設定辞書 160 161 Returns: 162 PostgreSQLConfig instance 163 164 Example: 165 >>> config = PostgreSQLConfig.from_dict({ 166 ... "host": "db.example.com", 167 ... "port": 5432, 168 ... "database": "mydb", 169 ... "username": "user", 170 ... "password": "secret", 171 ... "use_azure_ad": False, 172 ... "pool_max_size": 30 173 ... }) 174 """ 175 # Azure AD設定がある場合はネストして処理 176 azure_ad = None 177 if "azure_ad" in data and data.get("use_azure_ad"): 178 azure_ad = AzureADConfig.from_dict(data["azure_ad"]) 179 180 return cls( 181 provider=data.get("provider", "postgresql"), 182 host=data.get("host", "localhost"), 183 port=data.get("port", 5432), 184 database=data.get("database", "autonomousagent_db"), 185 use_azure_ad=data.get("use_azure_ad", False), 186 username=data.get("username"), 187 password=data.get("password"), 188 pool_min_size=data.get("pool_min_size", 5), 189 pool_max_size=data.get("pool_max_size", 20), 190 command_timeout=data.get("command_timeout", 60), 191 pool_timeout=data.get("pool_timeout", 30), 192 max_overflow=data.get("max_overflow", 10), 193 azure_ad=azure_ad, 194 api_proxy_url=data.get("api_proxy_url"), 195 )
辞書からPostgreSQLConfigを作成
Args: data: 設定辞書
Returns: PostgreSQLConfig instance
Example:
config = PostgreSQLConfig.from_dict({ ... "host": "db.example.com", ... "port": 5432, ... "database": "mydb", ... "username": "user", ... "password": "secret", ... "use_azure_ad": False, ... "pool_max_size": 30 ... })
197 @classmethod 198 def from_toml(cls, toml_path: str, section: str = "database") -> "PostgreSQLConfig": 199 """TOMLファイルからPostgreSQLConfigを作成 200 201 Args: 202 toml_path: TOMLファイルパス 203 section: セクション名(ドット区切りでネスト対応) 204 205 Returns: 206 PostgreSQLConfig instance 207 208 Example: 209 >>> # config.toml の [database] セクションを読み込み 210 >>> config = PostgreSQLConfig.from_toml("config.toml") 211 212 >>> # ネストされたセクション 213 >>> config = PostgreSQLConfig.from_toml("config.toml", section="services.database") 214 """ 215 if not TOML_AVAILABLE: 216 raise ImportError("toml library is required: pip install toml") 217 218 with open(toml_path, "r") as f: 219 toml_data = toml.load(f) 220 221 # ネストされたセクションをドット区切りで辿る 222 data = toml_data 223 for key in section.split("."): 224 data = data[key] 225 226 return cls.from_dict(data)
TOMLファイルからPostgreSQLConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: PostgreSQLConfig instance
Example:
config.toml の [database] セクションを読み込み
config = PostgreSQLConfig.from_toml("config.toml")
>>> # ネストされたセクション >>> config = PostgreSQLConfig.from_toml("config.toml", section="services.database")
228 @classmethod 229 def from_env(cls, prefix: str = "DB_") -> "PostgreSQLConfig": 230 """環境変数からPostgreSQLConfigを作成 231 232 環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。 233 234 Args: 235 prefix: 環境変数名のプレフィックス(デフォルト: "DB_") 236 237 Returns: 238 PostgreSQLConfig instance 239 240 Example: 241 >>> # 環境変数から作成(DB_HOST, DB_PORT, DB_DATABASE, etc.) 242 >>> config = PostgreSQLConfig.from_env() 243 244 >>> # カスタムプレフィックス(POSTGRES_HOST, POSTGRES_PORT, etc.) 245 >>> config = PostgreSQLConfig.from_env(prefix="POSTGRES_") 246 247 Environment Variables: 248 {prefix}HOST: データベースホスト(デフォルト: localhost) 249 {prefix}PORT: ポート番号(デフォルト: 5432) 250 {prefix}DATABASE: データベース名(デフォルト: autonomousagent_db) 251 {prefix}USER / {prefix}USERNAME: ユーザー名 252 {prefix}PASSWORD: パスワード 253 {prefix}USE_AZURE_AD: Azure AD認証を使用するか(true/false) 254 {prefix}POOL_MIN_SIZE: 最小プールサイズ(デフォルト: 5) 255 {prefix}POOL_MAX_SIZE: 最大プールサイズ(デフォルト: 20) 256 {prefix}API_PROXY_URL: API Proxy URL(CLIモード用) 257 """ 258 def get_env(key: str, default: str = "") -> str: 259 return os.environ.get(f"{prefix}{key}", default) 260 261 def get_env_int(key: str, default: int) -> int: 262 val = os.environ.get(f"{prefix}{key}") 263 return int(val) if val else default 264 265 def get_env_bool(key: str, default: bool = False) -> bool: 266 val = os.environ.get(f"{prefix}{key}", "").lower() 267 return val in ("true", "1", "yes") if val else default 268 269 # ユーザー名は USER または USERNAME どちらでも可 270 username = get_env("USER") or get_env("USERNAME") 271 272 return cls( 273 host=get_env("HOST", "localhost"), 274 port=get_env_int("PORT", 5432), 275 database=get_env("DATABASE", "autonomousagent_db"), 276 username=username or None, 277 password=get_env("PASSWORD") or None, 278 use_azure_ad=get_env_bool("USE_AZURE_AD"), 279 pool_min_size=get_env_int("POOL_MIN_SIZE", 5), 280 pool_max_size=get_env_int("POOL_MAX_SIZE", 20), 281 command_timeout=get_env_int("COMMAND_TIMEOUT", 60), 282 pool_timeout=get_env_int("POOL_TIMEOUT", 30), 283 max_overflow=get_env_int("MAX_OVERFLOW", 10), 284 api_proxy_url=get_env("API_PROXY_URL") or None, 285 )
環境変数からPostgreSQLConfigを作成
環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。
Args: prefix: 環境変数名のプレフィックス(デフォルト: "DB_")
Returns: PostgreSQLConfig instance
Example:
環境変数から作成(DB_HOST, DB_PORT, DB_DATABASE, etc.)
config = PostgreSQLConfig.from_env()
>>> # カスタムプレフィックス(POSTGRES_HOST, POSTGRES_PORT, etc.) >>> config = PostgreSQLConfig.from_env(prefix="POSTGRES_")Environment Variables: {prefix}HOST: データベースホスト(デフォルト: localhost) {prefix}PORT: ポート番号(デフォルト: 5432) {prefix}DATABASE: データベース名(デフォルト: autonomousagent_db) {prefix}USER / {prefix}USERNAME: ユーザー名 {prefix}PASSWORD: パスワード {prefix}USE_AZURE_AD: Azure AD認証を使用するか(true/false) {prefix}POOL_MIN_SIZE: 最小プールサイズ(デフォルト: 5) {prefix}POOL_MAX_SIZE: 最大プールサイズ(デフォルト: 20) {prefix}API_PROXY_URL: API Proxy URL(CLIモード用)
21@dataclass 22class AzureADConfig: 23 """Azure AD認証設定 24 25 Note: 26 client_secretはrepr=Falseでログ出力から除外されます。 27 28 Example: 29 >>> # 辞書から作成 30 >>> config = AzureADConfig.from_dict({ 31 ... "tenant_id": "xxx", 32 ... "client_id": "yyy", 33 ... "client_secret": "zzz" 34 ... }) 35 36 >>> # TOMLファイルから作成 37 >>> config = AzureADConfig.from_toml("config.toml", section="database.azure_ad") 38 """ 39 40 tenant_id: str 41 client_id: str 42 client_secret: str = field(repr=False) # シークレット: reprから除外 43 username: str = "" # PostgreSQL Azure AD authentication username 44 45 def __str__(self) -> str: 46 return f"AzureADConfig(tenant_id={self.tenant_id!r}, client_id={self.client_id!r}, client_secret=***)" 47 48 @classmethod 49 def from_dict(cls, data: Dict[str, Any]) -> "AzureADConfig": 50 """辞書からAzureADConfigを作成 51 52 Args: 53 data: 設定辞書 54 55 Returns: 56 AzureADConfig instance 57 58 Example: 59 >>> config = AzureADConfig.from_dict({ 60 ... "tenant_id": "your-tenant-id", 61 ... "client_id": "your-client-id", 62 ... "client_secret": "your-secret" 63 ... }) 64 """ 65 return cls( 66 tenant_id=data["tenant_id"], 67 client_id=data["client_id"], 68 client_secret=data["client_secret"], 69 username=data.get("username", ""), 70 ) 71 72 @classmethod 73 def from_toml(cls, toml_path: str, section: str = "azure_ad") -> "AzureADConfig": 74 """TOMLファイルからAzureADConfigを作成 75 76 Args: 77 toml_path: TOMLファイルパス 78 section: セクション名(ドット区切りでネスト対応) 79 80 Returns: 81 AzureADConfig instance 82 83 Example: 84 >>> # config.toml の [database.azure_ad] セクションを読み込み 85 >>> config = AzureADConfig.from_toml("config.toml", section="database.azure_ad") 86 """ 87 if not TOML_AVAILABLE: 88 raise ImportError("toml library is required: pip install toml") 89 90 with open(toml_path, "r") as f: 91 toml_data = toml.load(f) 92 93 # ネストされたセクションをドット区切りで辿る 94 data = toml_data 95 for key in section.split("."): 96 data = data[key] 97 98 return cls.from_dict(data)
Azure AD認証設定
Note: client_secretはrepr=Falseでログ出力から除外されます。
Example:
辞書から作成
config = AzureADConfig.from_dict({ ... "tenant_id": "xxx", ... "client_id": "yyy", ... "client_secret": "zzz" ... })
>>> # TOMLファイルから作成 >>> config = AzureADConfig.from_toml("config.toml", section="database.azure_ad")
48 @classmethod 49 def from_dict(cls, data: Dict[str, Any]) -> "AzureADConfig": 50 """辞書からAzureADConfigを作成 51 52 Args: 53 data: 設定辞書 54 55 Returns: 56 AzureADConfig instance 57 58 Example: 59 >>> config = AzureADConfig.from_dict({ 60 ... "tenant_id": "your-tenant-id", 61 ... "client_id": "your-client-id", 62 ... "client_secret": "your-secret" 63 ... }) 64 """ 65 return cls( 66 tenant_id=data["tenant_id"], 67 client_id=data["client_id"], 68 client_secret=data["client_secret"], 69 username=data.get("username", ""), 70 )
辞書からAzureADConfigを作成
Args: data: 設定辞書
Returns: AzureADConfig instance
Example:
config = AzureADConfig.from_dict({ ... "tenant_id": "your-tenant-id", ... "client_id": "your-client-id", ... "client_secret": "your-secret" ... })
72 @classmethod 73 def from_toml(cls, toml_path: str, section: str = "azure_ad") -> "AzureADConfig": 74 """TOMLファイルからAzureADConfigを作成 75 76 Args: 77 toml_path: TOMLファイルパス 78 section: セクション名(ドット区切りでネスト対応) 79 80 Returns: 81 AzureADConfig instance 82 83 Example: 84 >>> # config.toml の [database.azure_ad] セクションを読み込み 85 >>> config = AzureADConfig.from_toml("config.toml", section="database.azure_ad") 86 """ 87 if not TOML_AVAILABLE: 88 raise ImportError("toml library is required: pip install toml") 89 90 with open(toml_path, "r") as f: 91 toml_data = toml.load(f) 92 93 # ネストされたセクションをドット区切りで辿る 94 data = toml_data 95 for key in section.split("."): 96 data = data[key] 97 98 return cls.from_dict(data)
TOMLファイルからAzureADConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: AzureADConfig instance
Example:
config.toml の [database.azure_ad] セクションを読み込み
config = AzureADConfig.from_toml("config.toml", section="database.azure_ad")
45class DataAccess: 46 """ 47 汎用データアクセスクラス 48 49 テーブル操作の基本機能を提供: 50 - insert: 単一レコード挿入 51 - upsert: 挿入または更新 52 - select: 条件付き検索 53 - update: 条件付き更新 54 - delete: 条件付き削除 55 - execute_query: 生SQLクエリ実行 56 57 Security: 58 全てのテーブル名・カラム名は _validate_identifier() で検証されます。 59 不正な識別子が渡された場合は SQLInjectionError が発生します。 60 61 Example: 62 async with DataAccess(config) as da: 63 result = await da.select("users", where={"id": 1}) 64 """ 65 66 def __init__( 67 self, 68 config: PostgreSQLConfig, 69 *, 70 use_proxy: bool = False, 71 token_provider: Optional[Callable[[], Optional[str]]] = None, 72 ): 73 """ 74 Args: 75 config: PostgreSQL設定(位置引数、後方互換) 76 use_proxy: API Proxy経由でアクセスするか(デフォルト: False) 77 token_provider: Proxy使用時の認証トークン取得関数 78 79 Note: 80 後方互換性のため、use_proxyとtoken_providerはキーワード引数のみ。 81 既存コード(config引数のみ)は引き続き動作する。 82 """ 83 self.db = create_postgresql_manager( 84 config, 85 use_proxy=use_proxy, 86 token_provider=token_provider, 87 ) 88 self.logger = logger 89 90 async def __aenter__(self) -> "DataAccess": 91 """Context Manager: 開始時にデータベース接続を初期化""" 92 await self.initialize() 93 return self 94 95 async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: 96 """Context Manager: 終了時にデータベース接続をクローズ""" 97 await self.close() 98 99 def _validate_identifier(self, identifier: str, identifier_type: str = "identifier") -> str: 100 """ 101 SQL識別子(テーブル名・カラム名)を検証 102 103 Args: 104 identifier: 検証する識別子 105 identifier_type: エラーメッセージ用の識別子タイプ 106 107 Returns: 108 検証済みの識別子 109 110 Raises: 111 SQLInjectionError: 不正な識別子の場合 112 """ 113 if not identifier or not _IDENTIFIER_PATTERN.match(identifier): 114 raise SQLInjectionError(identifier, identifier_type) 115 return identifier 116 117 def _validate_identifiers(self, identifiers: List[str], identifier_type: str = "column") -> List[str]: 118 """ 119 複数のSQL識別子を検証 120 121 Args: 122 identifiers: 検証する識別子のリスト 123 identifier_type: エラーメッセージ用の識別子タイプ 124 125 Returns: 126 検証済みの識別子リスト 127 128 Raises: 129 SQLInjectionError: 不正な識別子が含まれる場合 130 """ 131 return [self._validate_identifier(ident, identifier_type) for ident in identifiers] 132 133 def _validate_order_by(self, order_by: str) -> str: 134 """ 135 ORDER BY句を検証 136 137 Args: 138 order_by: ORDER BY句(例: "created_at DESC") 139 140 Returns: 141 検証済みのORDER BY句 142 143 Raises: 144 SQLInjectionError: 不正なORDER BY句の場合 145 """ 146 if not order_by or not _ORDER_BY_PATTERN.match(order_by.strip()): 147 raise SQLInjectionError(order_by, "ORDER BY clause") 148 return order_by 149 150 async def initialize(self) -> None: 151 """データベース接続を初期化""" 152 await self.db.initialize() 153 154 def is_initialized(self) -> bool: 155 """データベース接続が初期化されているか確認 156 157 Note: 158 PostgreSQLManager: pool属性で判定 159 ApiPostgreSQLManager: pool属性がないため常にTrue(HTTP経由で都度接続) 160 """ 161 if hasattr(self.db, 'pool'): 162 return self.db.pool is not None 163 # ApiPostgreSQLManager等、pool属性を持たないマネージャは常に初期化済み扱い 164 return True 165 166 async def ensure_initialized(self) -> None: 167 """データベース接続が初期化されていなければ初期化""" 168 if not self.is_initialized(): 169 await self.initialize() 170 171 async def close(self) -> None: 172 """データベース接続を閉じる""" 173 await self.db.close() 174 175 async def execute_query( 176 self, query: str, params: tuple = () 177 ) -> Dict[str, Any]: 178 """ 179 生SQLクエリを実行 180 181 Args: 182 query: SQLクエリ 183 params: クエリパラメータ 184 185 Returns: 186 Dict with keys: success, data, error (optional) 187 """ 188 return await self.db.execute_query(query, params) 189 190 async def insert( 191 self, 192 table: str, 193 data: Dict[str, Any], 194 returning: Optional[List[str]] = None, 195 ) -> Dict[str, Any]: 196 """ 197 単一レコードを挿入 198 199 Args: 200 table: テーブル名 201 data: 挿入するデータ {column: value} 202 returning: 返却するカラム名のリスト 203 204 Returns: 205 Dict with keys: success, data, error (optional) 206 207 Raises: 208 SQLInjectionError: 不正なテーブル名・カラム名の場合 209 """ 210 # SQL Injection対策: 識別子を検証 211 safe_table = self._validate_identifier(table, "table") 212 columns = list(data.keys()) 213 safe_columns = self._validate_identifiers(columns, "column") 214 215 placeholders = [f"${i+1}" for i in range(len(safe_columns))] 216 values = [self._serialize_value(v) for v in data.values()] 217 218 query = f"INSERT INTO {safe_table} ({', '.join(safe_columns)}) VALUES ({', '.join(placeholders)})" 219 220 if returning: 221 safe_returning = self._validate_identifiers(returning, "returning column") 222 query += f" RETURNING {', '.join(safe_returning)}" 223 224 return await self.db.execute_query(query, tuple(values)) 225 226 async def upsert( 227 self, 228 table: str, 229 data: Dict[str, Any], 230 conflict_columns: List[str], 231 update_columns: Optional[List[str]] = None, 232 returning: Optional[List[str]] = None, 233 ) -> Dict[str, Any]: 234 """ 235 挿入または更新(UPSERT) 236 237 Args: 238 table: テーブル名 239 data: 挿入/更新するデータ {column: value} 240 conflict_columns: 競合判定に使うカラム 241 update_columns: 競合時に更新するカラム(Noneの場合は全カラム) 242 returning: 返却するカラム名のリスト 243 244 Returns: 245 Dict with keys: success, data, error (optional) 246 247 Raises: 248 SQLInjectionError: 不正なテーブル名・カラム名の場合 249 """ 250 # SQL Injection対策: 識別子を検証 251 safe_table = self._validate_identifier(table, "table") 252 columns = list(data.keys()) 253 safe_columns = self._validate_identifiers(columns, "column") 254 safe_conflict_columns = self._validate_identifiers(conflict_columns, "conflict column") 255 256 placeholders = [f"${i+1}" for i in range(len(safe_columns))] 257 values = [self._serialize_value(v) for v in data.values()] 258 259 # 更新対象カラムを決定 260 if update_columns is None: 261 update_columns = [c for c in columns if c not in conflict_columns] 262 safe_update_columns = self._validate_identifiers(update_columns, "update column") 263 264 # UPDATE SET句を構築 265 update_set = ", ".join( 266 [f"{col} = EXCLUDED.{col}" for col in safe_update_columns] 267 ) 268 269 query = f""" 270 INSERT INTO {safe_table} ({', '.join(safe_columns)}) 271 VALUES ({', '.join(placeholders)}) 272 ON CONFLICT ({', '.join(safe_conflict_columns)}) 273 DO UPDATE SET {update_set} 274 """ 275 276 if returning: 277 safe_returning = self._validate_identifiers(returning, "returning column") 278 query += f" RETURNING {', '.join(safe_returning)}" 279 280 return await self.db.execute_query(query, tuple(values)) 281 282 async def select( 283 self, 284 table: str, 285 columns: Optional[List[str]] = None, 286 where: Optional[Dict[str, Any]] = None, 287 order_by: Optional[str] = None, 288 limit: Optional[int] = None, 289 ) -> Dict[str, Any]: 290 """ 291 条件付き検索 292 293 Args: 294 table: テーブル名 295 columns: 取得するカラム(Noneの場合は全カラム) 296 where: 検索条件 {column: value} 297 order_by: ソート指定(例: "created_at DESC") 298 limit: 取得件数上限 299 300 Returns: 301 Dict with keys: success, data (List[Dict]), error (optional) 302 303 Raises: 304 SQLInjectionError: 不正なテーブル名・カラム名の場合 305 """ 306 # SQL Injection対策: 識別子を検証 307 safe_table = self._validate_identifier(table, "table") 308 309 if columns is None: 310 select_cols = "*" 311 else: 312 safe_columns = self._validate_identifiers(columns, "column") 313 select_cols = ", ".join(safe_columns) 314 315 query = f"SELECT {select_cols} FROM {safe_table}" 316 params: List[Any] = [] 317 318 if where: 319 conditions = [] 320 param_index = 1 321 for col, val in where.items(): 322 safe_col = self._validate_identifier(col, "where column") 323 if val is None: 324 conditions.append(f"{safe_col} IS NULL") 325 else: 326 conditions.append(f"{safe_col} = ${param_index}") 327 params.append(self._serialize_value(val)) 328 param_index += 1 329 query += f" WHERE {' AND '.join(conditions)}" 330 331 if order_by: 332 safe_order_by = self._validate_order_by(order_by) 333 query += f" ORDER BY {safe_order_by}" 334 335 if limit is not None: 336 if not isinstance(limit, int) or limit < 0: 337 raise ValueError(f"Invalid limit value: {limit}") 338 query += f" LIMIT {limit}" 339 340 return await self.db.execute_query(query, tuple(params)) 341 342 async def select_one( 343 self, 344 table: str, 345 columns: Optional[List[str]] = None, 346 where: Optional[Dict[str, Any]] = None, 347 ) -> Optional[Dict[str, Any]]: 348 """ 349 単一レコードを取得 350 351 Args: 352 table: テーブル名 353 columns: 取得するカラム 354 where: 検索条件 355 356 Returns: 357 Dict (結果がある場合) or None 358 """ 359 result = await self.select(table, columns, where, limit=1) 360 if result.get("success") and result.get("data"): 361 return result["data"][0] 362 return None 363 364 async def update( 365 self, 366 table: str, 367 data: Dict[str, Any], 368 where: Dict[str, Any], 369 returning: Optional[List[str]] = None, 370 ) -> Dict[str, Any]: 371 """ 372 条件付き更新 373 374 Args: 375 table: テーブル名 376 data: 更新するデータ {column: value} 377 where: 更新条件 {column: value} 378 returning: 返却するカラム名のリスト 379 380 Returns: 381 Dict with keys: success, data, error (optional) 382 383 Raises: 384 SQLInjectionError: 不正なテーブル名・カラム名の場合 385 """ 386 # SQL Injection対策: 識別子を検証 387 safe_table = self._validate_identifier(table, "table") 388 389 # SET句を構築 390 set_columns = list(data.keys()) 391 safe_set_columns = self._validate_identifiers(set_columns, "set column") 392 set_clause = ", ".join([f"{col} = ${i+1}" for i, col in enumerate(safe_set_columns)]) 393 params: List[Any] = [self._serialize_value(v) for v in data.values()] 394 395 # WHERE句を構築 396 where_columns = list(where.keys()) 397 safe_where_columns = self._validate_identifiers(where_columns, "where column") 398 where_clause = " AND ".join( 399 [f"{col} = ${len(safe_set_columns) + i + 1}" for i, col in enumerate(safe_where_columns)] 400 ) 401 params.extend([self._serialize_value(v) for v in where.values()]) 402 403 query = f"UPDATE {safe_table} SET {set_clause} WHERE {where_clause}" 404 405 if returning: 406 safe_returning = self._validate_identifiers(returning, "returning column") 407 query += f" RETURNING {', '.join(safe_returning)}" 408 409 return await self.db.execute_query(query, tuple(params)) 410 411 async def delete( 412 self, 413 table: str, 414 where: Dict[str, Any], 415 returning: Optional[List[str]] = None, 416 ) -> Dict[str, Any]: 417 """ 418 条件付き削除 419 420 Args: 421 table: テーブル名 422 where: 削除条件 {column: value} 423 returning: 返却するカラム名のリスト 424 425 Returns: 426 Dict with keys: success, data, error (optional) 427 428 Raises: 429 SQLInjectionError: 不正なテーブル名・カラム名の場合 430 """ 431 # SQL Injection対策: 識別子を検証 432 safe_table = self._validate_identifier(table, "table") 433 434 # WHERE句を構築 435 conditions = [] 436 params: List[Any] = [] 437 param_index = 1 438 for col, val in where.items(): 439 safe_col = self._validate_identifier(col, "where column") 440 if val is None: 441 conditions.append(f"{safe_col} IS NULL") 442 else: 443 conditions.append(f"{safe_col} = ${param_index}") 444 params.append(self._serialize_value(val)) 445 param_index += 1 446 447 query = f"DELETE FROM {safe_table} WHERE {' AND '.join(conditions)}" 448 449 if returning: 450 safe_returning = self._validate_identifiers(returning, "returning column") 451 query += f" RETURNING {', '.join(safe_returning)}" 452 453 return await self.db.execute_query(query, tuple(params)) 454 455 def _serialize_value(self, value: Any) -> Any: 456 """ 457 値をDB保存用にシリアライズ 458 459 Args: 460 value: シリアライズする値 461 462 Returns: 463 シリアライズされた値 464 """ 465 if isinstance(value, dict) or isinstance(value, list): 466 return json.dumps(value, ensure_ascii=False) 467 return value
汎用データアクセスクラス
テーブル操作の基本機能を提供:
- insert: 単一レコード挿入
- upsert: 挿入または更新
- select: 条件付き検索
- update: 条件付き更新
- delete: 条件付き削除
- execute_query: 生SQLクエリ実行
Security: 全てのテーブル名・カラム名は _validate_identifier() で検証されます。 不正な識別子が渡された場合は SQLInjectionError が発生します。
Example: async with DataAccess(config) as da: result = await da.select("users", where={"id": 1})
66 def __init__( 67 self, 68 config: PostgreSQLConfig, 69 *, 70 use_proxy: bool = False, 71 token_provider: Optional[Callable[[], Optional[str]]] = None, 72 ): 73 """ 74 Args: 75 config: PostgreSQL設定(位置引数、後方互換) 76 use_proxy: API Proxy経由でアクセスするか(デフォルト: False) 77 token_provider: Proxy使用時の認証トークン取得関数 78 79 Note: 80 後方互換性のため、use_proxyとtoken_providerはキーワード引数のみ。 81 既存コード(config引数のみ)は引き続き動作する。 82 """ 83 self.db = create_postgresql_manager( 84 config, 85 use_proxy=use_proxy, 86 token_provider=token_provider, 87 ) 88 self.logger = logger
Args: config: PostgreSQL設定(位置引数、後方互換) use_proxy: API Proxy経由でアクセスするか(デフォルト: False) token_provider: Proxy使用時の認証トークン取得関数
Note: 後方互換性のため、use_proxyとtoken_providerはキーワード引数のみ。 既存コード(config引数のみ)は引き続き動作する。
154 def is_initialized(self) -> bool: 155 """データベース接続が初期化されているか確認 156 157 Note: 158 PostgreSQLManager: pool属性で判定 159 ApiPostgreSQLManager: pool属性がないため常にTrue(HTTP経由で都度接続) 160 """ 161 if hasattr(self.db, 'pool'): 162 return self.db.pool is not None 163 # ApiPostgreSQLManager等、pool属性を持たないマネージャは常に初期化済み扱い 164 return True
データベース接続が初期化されているか確認
Note: PostgreSQLManager: pool属性で判定 ApiPostgreSQLManager: pool属性がないため常にTrue(HTTP経由で都度接続)
166 async def ensure_initialized(self) -> None: 167 """データベース接続が初期化されていなければ初期化""" 168 if not self.is_initialized(): 169 await self.initialize()
データベース接続が初期化されていなければ初期化
175 async def execute_query( 176 self, query: str, params: tuple = () 177 ) -> Dict[str, Any]: 178 """ 179 生SQLクエリを実行 180 181 Args: 182 query: SQLクエリ 183 params: クエリパラメータ 184 185 Returns: 186 Dict with keys: success, data, error (optional) 187 """ 188 return await self.db.execute_query(query, params)
生SQLクエリを実行
Args: query: SQLクエリ params: クエリパラメータ
Returns: Dict with keys: success, data, error (optional)
190 async def insert( 191 self, 192 table: str, 193 data: Dict[str, Any], 194 returning: Optional[List[str]] = None, 195 ) -> Dict[str, Any]: 196 """ 197 単一レコードを挿入 198 199 Args: 200 table: テーブル名 201 data: 挿入するデータ {column: value} 202 returning: 返却するカラム名のリスト 203 204 Returns: 205 Dict with keys: success, data, error (optional) 206 207 Raises: 208 SQLInjectionError: 不正なテーブル名・カラム名の場合 209 """ 210 # SQL Injection対策: 識別子を検証 211 safe_table = self._validate_identifier(table, "table") 212 columns = list(data.keys()) 213 safe_columns = self._validate_identifiers(columns, "column") 214 215 placeholders = [f"${i+1}" for i in range(len(safe_columns))] 216 values = [self._serialize_value(v) for v in data.values()] 217 218 query = f"INSERT INTO {safe_table} ({', '.join(safe_columns)}) VALUES ({', '.join(placeholders)})" 219 220 if returning: 221 safe_returning = self._validate_identifiers(returning, "returning column") 222 query += f" RETURNING {', '.join(safe_returning)}" 223 224 return await self.db.execute_query(query, tuple(values))
単一レコードを挿入
Args: table: テーブル名 data: 挿入するデータ {column: value} returning: 返却するカラム名のリスト
Returns: Dict with keys: success, data, error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
226 async def upsert( 227 self, 228 table: str, 229 data: Dict[str, Any], 230 conflict_columns: List[str], 231 update_columns: Optional[List[str]] = None, 232 returning: Optional[List[str]] = None, 233 ) -> Dict[str, Any]: 234 """ 235 挿入または更新(UPSERT) 236 237 Args: 238 table: テーブル名 239 data: 挿入/更新するデータ {column: value} 240 conflict_columns: 競合判定に使うカラム 241 update_columns: 競合時に更新するカラム(Noneの場合は全カラム) 242 returning: 返却するカラム名のリスト 243 244 Returns: 245 Dict with keys: success, data, error (optional) 246 247 Raises: 248 SQLInjectionError: 不正なテーブル名・カラム名の場合 249 """ 250 # SQL Injection対策: 識別子を検証 251 safe_table = self._validate_identifier(table, "table") 252 columns = list(data.keys()) 253 safe_columns = self._validate_identifiers(columns, "column") 254 safe_conflict_columns = self._validate_identifiers(conflict_columns, "conflict column") 255 256 placeholders = [f"${i+1}" for i in range(len(safe_columns))] 257 values = [self._serialize_value(v) for v in data.values()] 258 259 # 更新対象カラムを決定 260 if update_columns is None: 261 update_columns = [c for c in columns if c not in conflict_columns] 262 safe_update_columns = self._validate_identifiers(update_columns, "update column") 263 264 # UPDATE SET句を構築 265 update_set = ", ".join( 266 [f"{col} = EXCLUDED.{col}" for col in safe_update_columns] 267 ) 268 269 query = f""" 270 INSERT INTO {safe_table} ({', '.join(safe_columns)}) 271 VALUES ({', '.join(placeholders)}) 272 ON CONFLICT ({', '.join(safe_conflict_columns)}) 273 DO UPDATE SET {update_set} 274 """ 275 276 if returning: 277 safe_returning = self._validate_identifiers(returning, "returning column") 278 query += f" RETURNING {', '.join(safe_returning)}" 279 280 return await self.db.execute_query(query, tuple(values))
挿入または更新(UPSERT)
Args: table: テーブル名 data: 挿入/更新するデータ {column: value} conflict_columns: 競合判定に使うカラム update_columns: 競合時に更新するカラム(Noneの場合は全カラム) returning: 返却するカラム名のリスト
Returns: Dict with keys: success, data, error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
282 async def select( 283 self, 284 table: str, 285 columns: Optional[List[str]] = None, 286 where: Optional[Dict[str, Any]] = None, 287 order_by: Optional[str] = None, 288 limit: Optional[int] = None, 289 ) -> Dict[str, Any]: 290 """ 291 条件付き検索 292 293 Args: 294 table: テーブル名 295 columns: 取得するカラム(Noneの場合は全カラム) 296 where: 検索条件 {column: value} 297 order_by: ソート指定(例: "created_at DESC") 298 limit: 取得件数上限 299 300 Returns: 301 Dict with keys: success, data (List[Dict]), error (optional) 302 303 Raises: 304 SQLInjectionError: 不正なテーブル名・カラム名の場合 305 """ 306 # SQL Injection対策: 識別子を検証 307 safe_table = self._validate_identifier(table, "table") 308 309 if columns is None: 310 select_cols = "*" 311 else: 312 safe_columns = self._validate_identifiers(columns, "column") 313 select_cols = ", ".join(safe_columns) 314 315 query = f"SELECT {select_cols} FROM {safe_table}" 316 params: List[Any] = [] 317 318 if where: 319 conditions = [] 320 param_index = 1 321 for col, val in where.items(): 322 safe_col = self._validate_identifier(col, "where column") 323 if val is None: 324 conditions.append(f"{safe_col} IS NULL") 325 else: 326 conditions.append(f"{safe_col} = ${param_index}") 327 params.append(self._serialize_value(val)) 328 param_index += 1 329 query += f" WHERE {' AND '.join(conditions)}" 330 331 if order_by: 332 safe_order_by = self._validate_order_by(order_by) 333 query += f" ORDER BY {safe_order_by}" 334 335 if limit is not None: 336 if not isinstance(limit, int) or limit < 0: 337 raise ValueError(f"Invalid limit value: {limit}") 338 query += f" LIMIT {limit}" 339 340 return await self.db.execute_query(query, tuple(params))
条件付き検索
Args: table: テーブル名 columns: 取得するカラム(Noneの場合は全カラム) where: 検索条件 {column: value} order_by: ソート指定(例: "created_at DESC") limit: 取得件数上限
Returns: Dict with keys: success, data (List[Dict]), error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
342 async def select_one( 343 self, 344 table: str, 345 columns: Optional[List[str]] = None, 346 where: Optional[Dict[str, Any]] = None, 347 ) -> Optional[Dict[str, Any]]: 348 """ 349 単一レコードを取得 350 351 Args: 352 table: テーブル名 353 columns: 取得するカラム 354 where: 検索条件 355 356 Returns: 357 Dict (結果がある場合) or None 358 """ 359 result = await self.select(table, columns, where, limit=1) 360 if result.get("success") and result.get("data"): 361 return result["data"][0] 362 return None
単一レコードを取得
Args: table: テーブル名 columns: 取得するカラム where: 検索条件
Returns: Dict (結果がある場合) or None
364 async def update( 365 self, 366 table: str, 367 data: Dict[str, Any], 368 where: Dict[str, Any], 369 returning: Optional[List[str]] = None, 370 ) -> Dict[str, Any]: 371 """ 372 条件付き更新 373 374 Args: 375 table: テーブル名 376 data: 更新するデータ {column: value} 377 where: 更新条件 {column: value} 378 returning: 返却するカラム名のリスト 379 380 Returns: 381 Dict with keys: success, data, error (optional) 382 383 Raises: 384 SQLInjectionError: 不正なテーブル名・カラム名の場合 385 """ 386 # SQL Injection対策: 識別子を検証 387 safe_table = self._validate_identifier(table, "table") 388 389 # SET句を構築 390 set_columns = list(data.keys()) 391 safe_set_columns = self._validate_identifiers(set_columns, "set column") 392 set_clause = ", ".join([f"{col} = ${i+1}" for i, col in enumerate(safe_set_columns)]) 393 params: List[Any] = [self._serialize_value(v) for v in data.values()] 394 395 # WHERE句を構築 396 where_columns = list(where.keys()) 397 safe_where_columns = self._validate_identifiers(where_columns, "where column") 398 where_clause = " AND ".join( 399 [f"{col} = ${len(safe_set_columns) + i + 1}" for i, col in enumerate(safe_where_columns)] 400 ) 401 params.extend([self._serialize_value(v) for v in where.values()]) 402 403 query = f"UPDATE {safe_table} SET {set_clause} WHERE {where_clause}" 404 405 if returning: 406 safe_returning = self._validate_identifiers(returning, "returning column") 407 query += f" RETURNING {', '.join(safe_returning)}" 408 409 return await self.db.execute_query(query, tuple(params))
条件付き更新
Args: table: テーブル名 data: 更新するデータ {column: value} where: 更新条件 {column: value} returning: 返却するカラム名のリスト
Returns: Dict with keys: success, data, error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
411 async def delete( 412 self, 413 table: str, 414 where: Dict[str, Any], 415 returning: Optional[List[str]] = None, 416 ) -> Dict[str, Any]: 417 """ 418 条件付き削除 419 420 Args: 421 table: テーブル名 422 where: 削除条件 {column: value} 423 returning: 返却するカラム名のリスト 424 425 Returns: 426 Dict with keys: success, data, error (optional) 427 428 Raises: 429 SQLInjectionError: 不正なテーブル名・カラム名の場合 430 """ 431 # SQL Injection対策: 識別子を検証 432 safe_table = self._validate_identifier(table, "table") 433 434 # WHERE句を構築 435 conditions = [] 436 params: List[Any] = [] 437 param_index = 1 438 for col, val in where.items(): 439 safe_col = self._validate_identifier(col, "where column") 440 if val is None: 441 conditions.append(f"{safe_col} IS NULL") 442 else: 443 conditions.append(f"{safe_col} = ${param_index}") 444 params.append(self._serialize_value(val)) 445 param_index += 1 446 447 query = f"DELETE FROM {safe_table} WHERE {' AND '.join(conditions)}" 448 449 if returning: 450 safe_returning = self._validate_identifiers(returning, "returning column") 451 query += f" RETURNING {', '.join(safe_returning)}" 452 453 return await self.db.execute_query(query, tuple(params))
条件付き削除
Args: table: テーブル名 where: 削除条件 {column: value} returning: 返却するカラム名のリスト
Returns: Dict with keys: success, data, error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
23class ConfigAccess: 24 """ 25 設定データアクセスクラス 26 27 各種設定テーブルからの読み込み機能を提供: 28 - エージェント指示・設定 29 - ツール設定 30 - ガードレール設定 31 - システム設定 32 - MCP設定 33 """ 34 35 def __init__(self, db: DataAccess): 36 """ 37 Args: 38 db: DataAccessインスタンス 39 """ 40 self._db = db 41 self.logger = logger 42 43 def _is_cli_mode(self, mode: Optional[str] = None) -> bool: 44 """CLIモードかどうかを判定""" 45 effective_mode = mode or os.environ.get("EXECUTION_MODE", "pod") 46 return effective_mode in CLI_MODES or effective_mode.startswith("cli") 47 48 # ===================================================== 49 # AGENT INSTRUCTIONS 50 # ===================================================== 51 52 async def get_agent_instructions(self, agent_type: str) -> Dict[str, Any]: 53 """ 54 エージェント指示を取得 55 56 Args: 57 agent_type: エージェントタイプ ('intent', 'react', 'coordinator') 58 59 Returns: 60 Dict with keys: success, data (instructions string), error (optional) 61 """ 62 self.logger.info(f"[IN] get_agent_instructions: agent_type={agent_type}") 63 try: 64 result = await self._db.select( 65 table="agent_instructions", 66 columns=["instructions"], 67 where={"agent_type": agent_type}, 68 ) 69 70 # DBエラーを先にチェック 71 if not result.get("success"): 72 self.logger.error(f"[ERROR] get_agent_instructions: DB error for {agent_type}") 73 return { 74 "success": False, 75 "error": result.get("error", "Database query failed"), 76 "error_code": result.get("error_code", "DATABASE_ERROR") 77 } 78 79 # データが存在するかチェック 80 if result.get("data"): 81 instructions = result["data"][0]["instructions"] 82 self.logger.info(f"[OUT] get_agent_instructions: Found for {agent_type}") 83 return {"success": True, "data": instructions} 84 else: 85 self.logger.warning(f"[OUT] get_agent_instructions: Not found for {agent_type}") 86 return { 87 "success": False, 88 "error": f"No instructions found for agent type: {agent_type}", 89 "error_code": "INSTRUCTIONS_NOT_FOUND" 90 } 91 92 except Exception as e: 93 self.logger.error(f"[ERROR] get_agent_instructions: {str(e)}") 94 return { 95 "success": False, 96 "error": f"Database error: {str(e)}", 97 "error_code": "DATABASE_ERROR" 98 } 99 100 # ===================================================== 101 # AGENT CONFIGURATIONS 102 # ===================================================== 103 104 async def get_all_agent_configs(self) -> Dict[str, Any]: 105 """ 106 全エージェント設定を一括取得 107 108 Returns: 109 Dict with keys: success, data ({agent_type: {config_level: config}}), error (optional) 110 """ 111 self.logger.info("[IN] get_all_agent_configs") 112 try: 113 query = """ 114 SELECT agent_type, config, config_level 115 FROM agent_configurations 116 WHERE is_active = TRUE 117 ORDER BY agent_type, config_level 118 """ 119 120 result = await self._db.execute_query(query) 121 122 # DBエラーを先にチェック 123 if not result.get("success"): 124 self.logger.error("[ERROR] get_all_agent_configs: DB error") 125 return { 126 "success": False, 127 "error": result.get("error", "Database query failed"), 128 "error_code": result.get("error_code", "DATABASE_ERROR") 129 } 130 131 configs = {} 132 if result.get("data"): 133 for row in result["data"]: 134 agent_type = row['agent_type'] 135 config_level = row['config_level'] 136 137 if agent_type not in configs: 138 configs[agent_type] = {} 139 140 configs[agent_type][config_level] = row['config'] 141 142 self.logger.info(f"[OUT] get_all_agent_configs: Found {len(configs)} types") 143 else: 144 self.logger.info("[OUT] get_all_agent_configs: No configs found (empty result)") 145 146 return {"success": True, "data": configs} 147 148 except Exception as e: 149 self.logger.error(f"[ERROR] get_all_agent_configs: {str(e)}") 150 return { 151 "success": False, 152 "error": f"Failed to get agent configs: {str(e)}", 153 "error_code": "CONFIG_FETCH_ERROR" 154 } 155 156 async def get_agent_config( 157 self, 158 agent_type: str, 159 requested_level: str = "default" 160 ) -> Dict[str, Any]: 161 """ 162 エージェント設定を取得(優先順位: user > requested_level > default) 163 164 Args: 165 agent_type: エージェントタイプ 166 requested_level: リクエストレベル ('default', 'high_performance') 167 168 Returns: 169 Dict with keys: success, data ({config, level}), error (optional) 170 """ 171 self.logger.info(f"[IN] get_agent_config: agent_type={agent_type}, level={requested_level}") 172 try: 173 # まずユーザー設定を確認 174 user_config_query = """ 175 SELECT config FROM agent_configurations 176 WHERE config_level = 'user' AND agent_type = $1 AND is_active = TRUE 177 """ 178 result = await self._db.execute_query(user_config_query, (agent_type,)) 179 180 # DBエラーを先にチェック 181 if not result.get("success"): 182 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 183 return { 184 "success": False, 185 "error": result.get("error", "Database query failed"), 186 "error_code": result.get("error_code", "DATABASE_ERROR") 187 } 188 189 if result.get("data"): 190 self.logger.info(f"[OUT] get_agent_config: Found user config for {agent_type}") 191 return {"success": True, "data": {"config": result["data"][0]['config'], "level": "user"}} 192 193 # 次に要求されたレベルの設定を確認 194 level_config_query = """ 195 SELECT config FROM agent_configurations 196 WHERE config_level = $1 AND agent_type = $2 AND is_active = TRUE 197 """ 198 result = await self._db.execute_query(level_config_query, (requested_level, agent_type)) 199 200 if not result.get("success"): 201 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 202 return { 203 "success": False, 204 "error": result.get("error", "Database query failed"), 205 "error_code": result.get("error_code", "DATABASE_ERROR") 206 } 207 208 if result.get("data"): 209 self.logger.info(f"[OUT] get_agent_config: Found {requested_level} config for {agent_type}") 210 return {"success": True, "data": {"config": result["data"][0]['config'], "level": requested_level}} 211 212 # 最後にデフォルト設定を確認 213 if requested_level != "default": 214 result = await self._db.execute_query(level_config_query, ("default", agent_type)) 215 216 if not result.get("success"): 217 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 218 return { 219 "success": False, 220 "error": result.get("error", "Database query failed"), 221 "error_code": result.get("error_code", "DATABASE_ERROR") 222 } 223 224 if result.get("data"): 225 self.logger.info(f"[OUT] get_agent_config: Found default config for {agent_type}") 226 return {"success": True, "data": {"config": result["data"][0]['config'], "level": "default"}} 227 228 self.logger.warning(f"[OUT] get_agent_config: No config found for {agent_type}") 229 return { 230 "success": False, 231 "error": f"No configuration found for agent: {agent_type}", 232 "error_code": "CONFIG_NOT_FOUND" 233 } 234 235 except Exception as e: 236 self.logger.error(f"[ERROR] get_agent_config: {str(e)}") 237 return { 238 "success": False, 239 "error": str(e), 240 "error_code": "DATABASE_ERROR" 241 } 242 243 # ===================================================== 244 # TOOL CONFIGURATIONS 245 # ===================================================== 246 247 async def get_tool_config(self, tool_name: str) -> Dict[str, Any]: 248 """ 249 ツール設定を取得 250 251 Args: 252 tool_name: ツール名 253 254 Returns: 255 Dict with keys: success, data ({config, tool_name}), error (optional) 256 """ 257 self.logger.info(f"[IN] get_tool_config: tool_name={tool_name}") 258 try: 259 result = await self._db.select( 260 table="tool_configurations", 261 columns=["config"], 262 where={"tool_name": tool_name, "is_active": True}, 263 ) 264 265 # DBエラーを先にチェック 266 if not result.get("success"): 267 self.logger.error(f"[ERROR] get_tool_config: DB error for {tool_name}") 268 return { 269 "success": False, 270 "error": result.get("error", "Database query failed"), 271 "error_code": result.get("error_code", "DATABASE_ERROR") 272 } 273 274 if result.get("data"): 275 self.logger.info(f"[OUT] get_tool_config: Found for {tool_name}") 276 return {"success": True, "data": {"config": result["data"][0]['config'], "tool_name": tool_name}} 277 278 self.logger.warning(f"[OUT] get_tool_config: Not found for {tool_name}") 279 return { 280 "success": False, 281 "error": f"No configuration found for tool: {tool_name}", 282 "error_code": "CONFIG_NOT_FOUND" 283 } 284 285 except Exception as e: 286 self.logger.error(f"[ERROR] get_tool_config: {str(e)}") 287 return { 288 "success": False, 289 "error": str(e), 290 "error_code": "DATABASE_ERROR" 291 } 292 293 async def get_all_tool_configs(self) -> Dict[str, Any]: 294 """ 295 全ツール設定を一括取得 296 297 Returns: 298 Dict with keys: success, data ({tool_name: config}), error (optional) 299 """ 300 self.logger.info("[IN] get_all_tool_configs") 301 try: 302 query = """ 303 SELECT tool_name, config 304 FROM tool_configurations 305 WHERE is_active = TRUE 306 ORDER BY tool_name 307 """ 308 309 result = await self._db.execute_query(query) 310 311 # DBエラーを先にチェック 312 if not result.get("success"): 313 self.logger.error("[ERROR] get_all_tool_configs: DB error") 314 return { 315 "success": False, 316 "error": result.get("error", "Database query failed"), 317 "error_code": result.get("error_code", "DATABASE_ERROR") 318 } 319 320 configs = {} 321 if result.get("data"): 322 for row in result["data"]: 323 configs[row['tool_name']] = row['config'] 324 self.logger.info(f"[OUT] get_all_tool_configs: Found {len(configs)} tools") 325 else: 326 self.logger.info("[OUT] get_all_tool_configs: No configs found (empty result)") 327 328 return {"success": True, "data": configs} 329 330 except Exception as e: 331 self.logger.error(f"[ERROR] get_all_tool_configs: {str(e)}") 332 return { 333 "success": False, 334 "error": f"Failed to get tool configs: {str(e)}", 335 "error_code": "CONFIG_FETCH_ERROR" 336 } 337 338 # ===================================================== 339 # GUARDRAILS SETTINGS 340 # ===================================================== 341 342 async def get_guardrails_settings(self) -> Dict[str, Any]: 343 """ 344 Guardrails設定を取得 345 346 Returns: 347 Dict with keys: success, data (settings dict) 348 """ 349 self.logger.info("[IN] get_guardrails_settings") 350 351 default_settings = { 352 "prompt_shield_enabled": True, 353 "moderation_threshold": 2, 354 "pii_masking_enabled": True 355 } 356 357 try: 358 result = await self._db.select( 359 table="guardrails_settings", 360 columns=["prompt_shield_enabled", "moderation_threshold", "pii_masking_enabled"], 361 where={"id": 1}, 362 ) 363 364 # DBエラーを先にチェック 365 if not result.get("success"): 366 self.logger.error("[ERROR] get_guardrails_settings: DB error, using defaults") 367 # Guardrailsはデフォルト値を返す(フォールバック) 368 return {"success": True, "data": default_settings} 369 370 if result.get("data"): 371 settings = result["data"][0] 372 self.logger.info("[OUT] get_guardrails_settings: Found") 373 return {"success": True, "data": settings} 374 else: 375 self.logger.info("[OUT] get_guardrails_settings: No settings found, using defaults") 376 return {"success": True, "data": default_settings} 377 378 except Exception as e: 379 self.logger.error(f"[ERROR] get_guardrails_settings: {str(e)}, using defaults") 380 return {"success": True, "data": default_settings} 381 382 # ===================================================== 383 # SYSTEM SETTINGS 384 # ===================================================== 385 386 async def get_system_settings(self) -> Dict[str, Any]: 387 """ 388 システム設定を取得(アクティブな全レコード) 389 390 Returns: 391 Dict with keys: success, data (list of settings) 392 """ 393 self.logger.info("[IN] get_system_settings") 394 try: 395 query = """ 396 SELECT user_prompt, guardrail_text, is_active, priority 397 FROM system_settings 398 WHERE organization_id IS NULL AND is_active = true 399 """ 400 result = await self._db.execute_query(query) 401 402 # DBエラーを先にチェック 403 if not result.get("success"): 404 self.logger.error("[ERROR] get_system_settings: DB error") 405 return { 406 "success": False, 407 "error": result.get("error", "Database query failed"), 408 "error_code": result.get("error_code", "DATABASE_ERROR") 409 } 410 411 if result.get("data"): 412 self.logger.info(f"[OUT] get_system_settings: Found {len(result['data'])} settings") 413 return {"success": True, "data": result["data"]} 414 else: 415 self.logger.info("[OUT] get_system_settings: No settings found (empty result)") 416 return {"success": True, "data": []} 417 418 except Exception as e: 419 self.logger.error(f"[ERROR] get_system_settings: {str(e)}") 420 return { 421 "success": False, 422 "error": f"Database error: {str(e)}", 423 "error_code": "DATABASE_ERROR" 424 } 425 426 # ===================================================== 427 # MCP CONFIGURATIONS 428 # ===================================================== 429 430 async def get_mcp_configurations(self, execution_mode: Optional[str] = None) -> Dict[str, Any]: 431 """ 432 アクティブなMCPサーバー設定を取得(OAuth情報含む) 433 434 Args: 435 execution_mode: 実行モード ('cli', 'cli-api', 'pod')。Noneの場合は環境変数から取得 436 437 Returns: 438 Dict with keys: success, data (list of MCP configs) 439 各MCP configには以下を含む: 440 - name, server_url, cache_tools_list, headers, only_pod 441 - use_oauth: OAuthを使用するかどうか 442 - oauth_service: 紐づいたOAuthプロバイダーのservice名(github, slack等) 443 """ 444 self.logger.info("[IN] get_mcp_configurations") 445 try: 446 # CLI系モードの場合、only_pod=falseまたはNULLのレコードのみ取得 447 # oauth_providersとLEFT JOINしてOAuth情報も取得 448 if self._is_cli_mode(execution_mode): 449 query = """ 450 SELECT m.name, m.server_url, m.cache_tools_list, m.headers, m.only_pod, 451 m.use_oauth, o.service as oauth_service, m.transport_type 452 FROM mcp_server_configurations m 453 LEFT JOIN oauth_providers o ON o.mcp_server_id = m.id AND o.is_enabled = TRUE 454 WHERE m.is_active = TRUE 455 AND (m.only_pod = FALSE OR m.only_pod IS NULL) 456 ORDER BY m.name 457 """ 458 self.logger.info("CLI mode: Filtering out pod-only MCP servers") 459 else: 460 query = """ 461 SELECT m.name, m.server_url, m.cache_tools_list, m.headers, m.only_pod, 462 m.use_oauth, o.service as oauth_service, m.transport_type 463 FROM mcp_server_configurations m 464 LEFT JOIN oauth_providers o ON o.mcp_server_id = m.id AND o.is_enabled = TRUE 465 WHERE m.is_active = TRUE 466 ORDER BY m.name 467 """ 468 469 result = await self._db.execute_query(query) 470 471 # DBエラーを先にチェック 472 if not result.get("success"): 473 self.logger.error("[ERROR] get_mcp_configurations: DB error") 474 return { 475 "success": False, 476 "error": result.get("error", "Database query failed"), 477 "error_code": result.get("error_code", "DATABASE_ERROR") 478 } 479 480 if result.get("data"): 481 configurations = [] 482 for row in result["data"]: 483 # headersをdictに変換 484 headers_raw = row['headers'] 485 headers = None 486 if headers_raw: 487 if isinstance(headers_raw, str): 488 try: 489 headers = json.loads(headers_raw) 490 except Exception as e: 491 self.logger.warning(f"Failed to parse headers for '{row['name']}': {e}") 492 headers = None 493 elif isinstance(headers_raw, dict): 494 headers = headers_raw 495 496 if headers == {}: 497 headers = None 498 499 config = { 500 "name": row['name'], 501 "server_url": row['server_url'], 502 "cache_tools_list": row['cache_tools_list'], 503 "headers": headers, 504 "only_pod": row.get('only_pod', False), 505 # OAuth情報 506 "use_oauth": row.get('use_oauth', False), 507 "oauth_service": row.get('oauth_service'), # github, slack, google, office365等 508 # Transport type: 'sse', 'streamable_http', 'stdio' 509 "transport_type": row.get('transport_type', 'sse'), 510 } 511 configurations.append(config) 512 513 self.logger.info(f"[OUT] get_mcp_configurations: Found {len(configurations)} MCP servers") 514 return {"success": True, "data": configurations} 515 else: 516 self.logger.info("[OUT] get_mcp_configurations: No MCP configs found (empty result)") 517 return {"success": True, "data": []} 518 519 except Exception as e: 520 self.logger.error(f"[ERROR] get_mcp_configurations: {str(e)}") 521 return { 522 "success": False, 523 "error": f"Database error: {str(e)}", 524 "error_code": "DATABASE_ERROR" 525 }
設定データアクセスクラス
各種設定テーブルからの読み込み機能を提供:
- エージェント指示・設定
- ツール設定
- ガードレール設定
- システム設定
- MCP設定
35 def __init__(self, db: DataAccess): 36 """ 37 Args: 38 db: DataAccessインスタンス 39 """ 40 self._db = db 41 self.logger = logger
Args: db: DataAccessインスタンス
52 async def get_agent_instructions(self, agent_type: str) -> Dict[str, Any]: 53 """ 54 エージェント指示を取得 55 56 Args: 57 agent_type: エージェントタイプ ('intent', 'react', 'coordinator') 58 59 Returns: 60 Dict with keys: success, data (instructions string), error (optional) 61 """ 62 self.logger.info(f"[IN] get_agent_instructions: agent_type={agent_type}") 63 try: 64 result = await self._db.select( 65 table="agent_instructions", 66 columns=["instructions"], 67 where={"agent_type": agent_type}, 68 ) 69 70 # DBエラーを先にチェック 71 if not result.get("success"): 72 self.logger.error(f"[ERROR] get_agent_instructions: DB error for {agent_type}") 73 return { 74 "success": False, 75 "error": result.get("error", "Database query failed"), 76 "error_code": result.get("error_code", "DATABASE_ERROR") 77 } 78 79 # データが存在するかチェック 80 if result.get("data"): 81 instructions = result["data"][0]["instructions"] 82 self.logger.info(f"[OUT] get_agent_instructions: Found for {agent_type}") 83 return {"success": True, "data": instructions} 84 else: 85 self.logger.warning(f"[OUT] get_agent_instructions: Not found for {agent_type}") 86 return { 87 "success": False, 88 "error": f"No instructions found for agent type: {agent_type}", 89 "error_code": "INSTRUCTIONS_NOT_FOUND" 90 } 91 92 except Exception as e: 93 self.logger.error(f"[ERROR] get_agent_instructions: {str(e)}") 94 return { 95 "success": False, 96 "error": f"Database error: {str(e)}", 97 "error_code": "DATABASE_ERROR" 98 }
エージェント指示を取得
Args: agent_type: エージェントタイプ ('intent', 'react', 'coordinator')
Returns: Dict with keys: success, data (instructions string), error (optional)
104 async def get_all_agent_configs(self) -> Dict[str, Any]: 105 """ 106 全エージェント設定を一括取得 107 108 Returns: 109 Dict with keys: success, data ({agent_type: {config_level: config}}), error (optional) 110 """ 111 self.logger.info("[IN] get_all_agent_configs") 112 try: 113 query = """ 114 SELECT agent_type, config, config_level 115 FROM agent_configurations 116 WHERE is_active = TRUE 117 ORDER BY agent_type, config_level 118 """ 119 120 result = await self._db.execute_query(query) 121 122 # DBエラーを先にチェック 123 if not result.get("success"): 124 self.logger.error("[ERROR] get_all_agent_configs: DB error") 125 return { 126 "success": False, 127 "error": result.get("error", "Database query failed"), 128 "error_code": result.get("error_code", "DATABASE_ERROR") 129 } 130 131 configs = {} 132 if result.get("data"): 133 for row in result["data"]: 134 agent_type = row['agent_type'] 135 config_level = row['config_level'] 136 137 if agent_type not in configs: 138 configs[agent_type] = {} 139 140 configs[agent_type][config_level] = row['config'] 141 142 self.logger.info(f"[OUT] get_all_agent_configs: Found {len(configs)} types") 143 else: 144 self.logger.info("[OUT] get_all_agent_configs: No configs found (empty result)") 145 146 return {"success": True, "data": configs} 147 148 except Exception as e: 149 self.logger.error(f"[ERROR] get_all_agent_configs: {str(e)}") 150 return { 151 "success": False, 152 "error": f"Failed to get agent configs: {str(e)}", 153 "error_code": "CONFIG_FETCH_ERROR" 154 }
全エージェント設定を一括取得
Returns: Dict with keys: success, data ({agent_type: {config_level: config}}), error (optional)
156 async def get_agent_config( 157 self, 158 agent_type: str, 159 requested_level: str = "default" 160 ) -> Dict[str, Any]: 161 """ 162 エージェント設定を取得(優先順位: user > requested_level > default) 163 164 Args: 165 agent_type: エージェントタイプ 166 requested_level: リクエストレベル ('default', 'high_performance') 167 168 Returns: 169 Dict with keys: success, data ({config, level}), error (optional) 170 """ 171 self.logger.info(f"[IN] get_agent_config: agent_type={agent_type}, level={requested_level}") 172 try: 173 # まずユーザー設定を確認 174 user_config_query = """ 175 SELECT config FROM agent_configurations 176 WHERE config_level = 'user' AND agent_type = $1 AND is_active = TRUE 177 """ 178 result = await self._db.execute_query(user_config_query, (agent_type,)) 179 180 # DBエラーを先にチェック 181 if not result.get("success"): 182 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 183 return { 184 "success": False, 185 "error": result.get("error", "Database query failed"), 186 "error_code": result.get("error_code", "DATABASE_ERROR") 187 } 188 189 if result.get("data"): 190 self.logger.info(f"[OUT] get_agent_config: Found user config for {agent_type}") 191 return {"success": True, "data": {"config": result["data"][0]['config'], "level": "user"}} 192 193 # 次に要求されたレベルの設定を確認 194 level_config_query = """ 195 SELECT config FROM agent_configurations 196 WHERE config_level = $1 AND agent_type = $2 AND is_active = TRUE 197 """ 198 result = await self._db.execute_query(level_config_query, (requested_level, agent_type)) 199 200 if not result.get("success"): 201 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 202 return { 203 "success": False, 204 "error": result.get("error", "Database query failed"), 205 "error_code": result.get("error_code", "DATABASE_ERROR") 206 } 207 208 if result.get("data"): 209 self.logger.info(f"[OUT] get_agent_config: Found {requested_level} config for {agent_type}") 210 return {"success": True, "data": {"config": result["data"][0]['config'], "level": requested_level}} 211 212 # 最後にデフォルト設定を確認 213 if requested_level != "default": 214 result = await self._db.execute_query(level_config_query, ("default", agent_type)) 215 216 if not result.get("success"): 217 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 218 return { 219 "success": False, 220 "error": result.get("error", "Database query failed"), 221 "error_code": result.get("error_code", "DATABASE_ERROR") 222 } 223 224 if result.get("data"): 225 self.logger.info(f"[OUT] get_agent_config: Found default config for {agent_type}") 226 return {"success": True, "data": {"config": result["data"][0]['config'], "level": "default"}} 227 228 self.logger.warning(f"[OUT] get_agent_config: No config found for {agent_type}") 229 return { 230 "success": False, 231 "error": f"No configuration found for agent: {agent_type}", 232 "error_code": "CONFIG_NOT_FOUND" 233 } 234 235 except Exception as e: 236 self.logger.error(f"[ERROR] get_agent_config: {str(e)}") 237 return { 238 "success": False, 239 "error": str(e), 240 "error_code": "DATABASE_ERROR" 241 }
エージェント設定を取得(優先順位: user > requested_level > default)
Args: agent_type: エージェントタイプ requested_level: リクエストレベル ('default', 'high_performance')
Returns: Dict with keys: success, data ({config, level}), error (optional)
247 async def get_tool_config(self, tool_name: str) -> Dict[str, Any]: 248 """ 249 ツール設定を取得 250 251 Args: 252 tool_name: ツール名 253 254 Returns: 255 Dict with keys: success, data ({config, tool_name}), error (optional) 256 """ 257 self.logger.info(f"[IN] get_tool_config: tool_name={tool_name}") 258 try: 259 result = await self._db.select( 260 table="tool_configurations", 261 columns=["config"], 262 where={"tool_name": tool_name, "is_active": True}, 263 ) 264 265 # DBエラーを先にチェック 266 if not result.get("success"): 267 self.logger.error(f"[ERROR] get_tool_config: DB error for {tool_name}") 268 return { 269 "success": False, 270 "error": result.get("error", "Database query failed"), 271 "error_code": result.get("error_code", "DATABASE_ERROR") 272 } 273 274 if result.get("data"): 275 self.logger.info(f"[OUT] get_tool_config: Found for {tool_name}") 276 return {"success": True, "data": {"config": result["data"][0]['config'], "tool_name": tool_name}} 277 278 self.logger.warning(f"[OUT] get_tool_config: Not found for {tool_name}") 279 return { 280 "success": False, 281 "error": f"No configuration found for tool: {tool_name}", 282 "error_code": "CONFIG_NOT_FOUND" 283 } 284 285 except Exception as e: 286 self.logger.error(f"[ERROR] get_tool_config: {str(e)}") 287 return { 288 "success": False, 289 "error": str(e), 290 "error_code": "DATABASE_ERROR" 291 }
ツール設定を取得
Args: tool_name: ツール名
Returns: Dict with keys: success, data ({config, tool_name}), error (optional)
293 async def get_all_tool_configs(self) -> Dict[str, Any]: 294 """ 295 全ツール設定を一括取得 296 297 Returns: 298 Dict with keys: success, data ({tool_name: config}), error (optional) 299 """ 300 self.logger.info("[IN] get_all_tool_configs") 301 try: 302 query = """ 303 SELECT tool_name, config 304 FROM tool_configurations 305 WHERE is_active = TRUE 306 ORDER BY tool_name 307 """ 308 309 result = await self._db.execute_query(query) 310 311 # DBエラーを先にチェック 312 if not result.get("success"): 313 self.logger.error("[ERROR] get_all_tool_configs: DB error") 314 return { 315 "success": False, 316 "error": result.get("error", "Database query failed"), 317 "error_code": result.get("error_code", "DATABASE_ERROR") 318 } 319 320 configs = {} 321 if result.get("data"): 322 for row in result["data"]: 323 configs[row['tool_name']] = row['config'] 324 self.logger.info(f"[OUT] get_all_tool_configs: Found {len(configs)} tools") 325 else: 326 self.logger.info("[OUT] get_all_tool_configs: No configs found (empty result)") 327 328 return {"success": True, "data": configs} 329 330 except Exception as e: 331 self.logger.error(f"[ERROR] get_all_tool_configs: {str(e)}") 332 return { 333 "success": False, 334 "error": f"Failed to get tool configs: {str(e)}", 335 "error_code": "CONFIG_FETCH_ERROR" 336 }
全ツール設定を一括取得
Returns: Dict with keys: success, data ({tool_name: config}), error (optional)
342 async def get_guardrails_settings(self) -> Dict[str, Any]: 343 """ 344 Guardrails設定を取得 345 346 Returns: 347 Dict with keys: success, data (settings dict) 348 """ 349 self.logger.info("[IN] get_guardrails_settings") 350 351 default_settings = { 352 "prompt_shield_enabled": True, 353 "moderation_threshold": 2, 354 "pii_masking_enabled": True 355 } 356 357 try: 358 result = await self._db.select( 359 table="guardrails_settings", 360 columns=["prompt_shield_enabled", "moderation_threshold", "pii_masking_enabled"], 361 where={"id": 1}, 362 ) 363 364 # DBエラーを先にチェック 365 if not result.get("success"): 366 self.logger.error("[ERROR] get_guardrails_settings: DB error, using defaults") 367 # Guardrailsはデフォルト値を返す(フォールバック) 368 return {"success": True, "data": default_settings} 369 370 if result.get("data"): 371 settings = result["data"][0] 372 self.logger.info("[OUT] get_guardrails_settings: Found") 373 return {"success": True, "data": settings} 374 else: 375 self.logger.info("[OUT] get_guardrails_settings: No settings found, using defaults") 376 return {"success": True, "data": default_settings} 377 378 except Exception as e: 379 self.logger.error(f"[ERROR] get_guardrails_settings: {str(e)}, using defaults") 380 return {"success": True, "data": default_settings}
Guardrails設定を取得
Returns: Dict with keys: success, data (settings dict)
386 async def get_system_settings(self) -> Dict[str, Any]: 387 """ 388 システム設定を取得(アクティブな全レコード) 389 390 Returns: 391 Dict with keys: success, data (list of settings) 392 """ 393 self.logger.info("[IN] get_system_settings") 394 try: 395 query = """ 396 SELECT user_prompt, guardrail_text, is_active, priority 397 FROM system_settings 398 WHERE organization_id IS NULL AND is_active = true 399 """ 400 result = await self._db.execute_query(query) 401 402 # DBエラーを先にチェック 403 if not result.get("success"): 404 self.logger.error("[ERROR] get_system_settings: DB error") 405 return { 406 "success": False, 407 "error": result.get("error", "Database query failed"), 408 "error_code": result.get("error_code", "DATABASE_ERROR") 409 } 410 411 if result.get("data"): 412 self.logger.info(f"[OUT] get_system_settings: Found {len(result['data'])} settings") 413 return {"success": True, "data": result["data"]} 414 else: 415 self.logger.info("[OUT] get_system_settings: No settings found (empty result)") 416 return {"success": True, "data": []} 417 418 except Exception as e: 419 self.logger.error(f"[ERROR] get_system_settings: {str(e)}") 420 return { 421 "success": False, 422 "error": f"Database error: {str(e)}", 423 "error_code": "DATABASE_ERROR" 424 }
システム設定を取得(アクティブな全レコード)
Returns: Dict with keys: success, data (list of settings)
430 async def get_mcp_configurations(self, execution_mode: Optional[str] = None) -> Dict[str, Any]: 431 """ 432 アクティブなMCPサーバー設定を取得(OAuth情報含む) 433 434 Args: 435 execution_mode: 実行モード ('cli', 'cli-api', 'pod')。Noneの場合は環境変数から取得 436 437 Returns: 438 Dict with keys: success, data (list of MCP configs) 439 各MCP configには以下を含む: 440 - name, server_url, cache_tools_list, headers, only_pod 441 - use_oauth: OAuthを使用するかどうか 442 - oauth_service: 紐づいたOAuthプロバイダーのservice名(github, slack等) 443 """ 444 self.logger.info("[IN] get_mcp_configurations") 445 try: 446 # CLI系モードの場合、only_pod=falseまたはNULLのレコードのみ取得 447 # oauth_providersとLEFT JOINしてOAuth情報も取得 448 if self._is_cli_mode(execution_mode): 449 query = """ 450 SELECT m.name, m.server_url, m.cache_tools_list, m.headers, m.only_pod, 451 m.use_oauth, o.service as oauth_service, m.transport_type 452 FROM mcp_server_configurations m 453 LEFT JOIN oauth_providers o ON o.mcp_server_id = m.id AND o.is_enabled = TRUE 454 WHERE m.is_active = TRUE 455 AND (m.only_pod = FALSE OR m.only_pod IS NULL) 456 ORDER BY m.name 457 """ 458 self.logger.info("CLI mode: Filtering out pod-only MCP servers") 459 else: 460 query = """ 461 SELECT m.name, m.server_url, m.cache_tools_list, m.headers, m.only_pod, 462 m.use_oauth, o.service as oauth_service, m.transport_type 463 FROM mcp_server_configurations m 464 LEFT JOIN oauth_providers o ON o.mcp_server_id = m.id AND o.is_enabled = TRUE 465 WHERE m.is_active = TRUE 466 ORDER BY m.name 467 """ 468 469 result = await self._db.execute_query(query) 470 471 # DBエラーを先にチェック 472 if not result.get("success"): 473 self.logger.error("[ERROR] get_mcp_configurations: DB error") 474 return { 475 "success": False, 476 "error": result.get("error", "Database query failed"), 477 "error_code": result.get("error_code", "DATABASE_ERROR") 478 } 479 480 if result.get("data"): 481 configurations = [] 482 for row in result["data"]: 483 # headersをdictに変換 484 headers_raw = row['headers'] 485 headers = None 486 if headers_raw: 487 if isinstance(headers_raw, str): 488 try: 489 headers = json.loads(headers_raw) 490 except Exception as e: 491 self.logger.warning(f"Failed to parse headers for '{row['name']}': {e}") 492 headers = None 493 elif isinstance(headers_raw, dict): 494 headers = headers_raw 495 496 if headers == {}: 497 headers = None 498 499 config = { 500 "name": row['name'], 501 "server_url": row['server_url'], 502 "cache_tools_list": row['cache_tools_list'], 503 "headers": headers, 504 "only_pod": row.get('only_pod', False), 505 # OAuth情報 506 "use_oauth": row.get('use_oauth', False), 507 "oauth_service": row.get('oauth_service'), # github, slack, google, office365等 508 # Transport type: 'sse', 'streamable_http', 'stdio' 509 "transport_type": row.get('transport_type', 'sse'), 510 } 511 configurations.append(config) 512 513 self.logger.info(f"[OUT] get_mcp_configurations: Found {len(configurations)} MCP servers") 514 return {"success": True, "data": configurations} 515 else: 516 self.logger.info("[OUT] get_mcp_configurations: No MCP configs found (empty result)") 517 return {"success": True, "data": []} 518 519 except Exception as e: 520 self.logger.error(f"[ERROR] get_mcp_configurations: {str(e)}") 521 return { 522 "success": False, 523 "error": f"Database error: {str(e)}", 524 "error_code": "DATABASE_ERROR" 525 }
アクティブなMCPサーバー設定を取得(OAuth情報含む)
Args: execution_mode: 実行モード ('cli', 'cli-api', 'pod')。Noneの場合は環境変数から取得
Returns: Dict with keys: success, data (list of MCP configs) 各MCP configには以下を含む: - name, server_url, cache_tools_list, headers, only_pod - use_oauth: OAuthを使用するかどうか - oauth_service: 紐づいたOAuthプロバイダーのservice名(github, slack等)
151class EmbeddingGenerator: 152 """ 153 Azure OpenAI Embedding生成クラス 154 155 テキストをベクトル表現に変換します。 156 LRUキャッシュとレートリミット対応により、効率的かつ安定的にEmbeddingを生成します。 157 158 Features: 159 - 非同期バッチEmbedding生成(batch_generate) 160 - LRUメモリキャッシュによる重複リクエスト削減 161 - 指数バックオフ+ジッターによるレートリミット対応 162 - Azure OpenAI text-embedding-ada-002モデル対応 163 - キャッシュ統計とキャッシュクリア機能 164 165 Example: 166 >>> from agenticstar_platform import EmbeddingGenerator, EmbeddingConfig 167 >>> 168 >>> # 設定を作成 169 >>> config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding") 170 >>> 171 >>> # クライアント初期化 172 >>> generator = EmbeddingGenerator(config) 173 >>> 174 >>> # 単一テキストのEmbedding生成 175 >>> embedding = await generator.generate("Azure OpenAIの使い方を教えてください") 176 >>> print(f"Embedding dimensions: {len(embedding)}") # 1536 177 >>> 178 >>> # バッチEmbedding生成(効率的) 179 >>> texts = ["テキスト1", "テキスト2", "テキスト3"] 180 >>> embeddings = await generator.batch_generate(texts, batch_size=16) 181 >>> print(f"Generated {len(embeddings)} embeddings") 182 >>> 183 >>> # キャッシュ統計確認 184 >>> stats = generator.get_cache_stats() 185 >>> print(f"Cache utilization: {stats['cache_utilization']:.1%}") 186 """ 187 188 def __init__(self, config: EmbeddingConfig): 189 """ 190 Initialize Embedding Generator 191 192 Args: 193 config: EmbeddingConfig instance 194 """ 195 self.client = AsyncAzureOpenAI( 196 azure_endpoint=config.base_url, 197 api_key=config.api_key, 198 api_version=config.api_version, 199 ) 200 self.deployment_name = config.model 201 self.model_name = config.model # For model_version tracking 202 self.dimensions = config.dimensions 203 204 # リトライ設定 205 self._max_retries = config.max_retries 206 self._base_delay = config.base_delay 207 self._max_delay = config.max_delay 208 209 # LRU Cache 210 self._memory_cache: Dict[str, List[float]] = {} 211 self._cache_order: List[str] = [] # LRU tracking 212 self._max_cache_size = config.max_cache_size 213 214 logger.info( 215 f"EmbeddingGenerator initialized - model={config.model}, " 216 f"cache_size={config.max_cache_size}, max_retries={config.max_retries}" 217 ) 218 219 async def generate(self, text: str) -> List[float]: 220 """ 221 単一テキストのEmbeddingを生成(キャッシュ対応) 222 223 同一テキストは自動的にキャッシュからEmbeddingを取得します。 224 レートリミットに達した場合は自動リトライを行います。 225 226 Args: 227 text: Embedding対象のテキスト 228 229 Returns: 230 List[float]: Embeddingベクトル(text-embedding-ada-002の場合は1536次元) 231 232 Raises: 233 RateLimitExceededError: レートリミット超過でリトライ上限到達時 234 EmbeddingError: その他のAPIエラー 235 236 Example: 237 >>> # 単一テキストのEmbedding生成 238 >>> embedding = await generator.generate("質問に回答する方法") 239 >>> print(f"Dimensions: {len(embedding)}") # 1536 240 >>> 241 >>> # キャッシュ済みの場合は即座に返る 242 >>> embedding_cached = await generator.generate("質問に回答する方法") 243 """ 244 # Cache key 245 cache_key = self._make_cache_key(text) 246 247 # Cache check 248 if cache_key in self._memory_cache: 249 logger.debug(f"Cache hit for key={cache_key[:16]}...") 250 return self._memory_cache[cache_key] 251 252 # API call with retry 253 last_error = None 254 for attempt in range(self._max_retries): 255 try: 256 logger.debug(f"Generating embedding - text_len={len(text)} chars, attempt={attempt + 1}") 257 response = await self.client.embeddings.create( 258 input=text, model=self.deployment_name 259 ) 260 embedding = response.data[0].embedding 261 262 # Cache update 263 self._update_memory_cache(cache_key, embedding) 264 265 logger.debug(f"Embedding generated - dims={len(embedding)}") 266 return embedding 267 268 except Exception as e: 269 last_error = e 270 if "429" in str(e) or "rate" in str(e).lower(): 271 delay = self._calculate_backoff_delay(attempt) 272 logger.warning( 273 f"Rate limit hit (attempt {attempt + 1}/{self._max_retries}), " 274 f"retrying after {delay:.1f}s..." 275 ) 276 await asyncio.sleep(delay) 277 continue 278 else: 279 logger.error(f"Embedding generation failed: {e}") 280 raise EmbeddingError(f"Embedding generation failed: {e}") from e 281 282 # リトライ上限到達 283 raise RateLimitExceededError( 284 f"Rate limit exceeded after {self._max_retries} retries: {last_error}" 285 ) 286 287 async def batch_generate( 288 self, texts: List[str], batch_size: int = 16 289 ) -> List[List[float]]: 290 """ 291 複数テキストのEmbeddingを一括生成(バッチ処理) 292 293 大量のテキストを効率的にEmbedding化します。 294 キャッシュ済みのテキストはスキップし、APIコール数を削減します。 295 296 Args: 297 texts: Embedding対象のテキストリスト 298 batch_size: APIバッチサイズ(推奨: 16-32)。大きすぎるとレートリミットに達しやすくなる 299 300 Returns: 301 List[List[float]]: Embeddingベクトルのリスト(入力と同じ順序) 302 303 Raises: 304 RateLimitExceededError: レートリミット超過でリトライ上限到達時 305 EmbeddingError: その他のAPIエラー 306 307 Example: 308 >>> # バッチEmbedding生成 309 >>> texts = [ 310 ... "ドキュメント1の内容", 311 ... "ドキュメント2の内容", 312 ... "ドキュメント3の内容" 313 ... ] 314 >>> embeddings = await generator.batch_generate(texts, batch_size=16) 315 >>> print(f"Generated {len(embeddings)} embeddings") 316 >>> 317 >>> # 各Embeddingを確認 318 >>> for i, emb in enumerate(embeddings): 319 ... print(f"Text {i}: {len(emb)} dimensions") 320 """ 321 logger.info( 322 f"Batch embedding generation - texts={len(texts)}, " 323 f"batch_size={batch_size}" 324 ) 325 326 # Check cache 327 embeddings: List[Optional[List[float]]] = [] 328 uncached_texts: List[str] = [] 329 uncached_indices: List[int] = [] 330 331 for i, text in enumerate(texts): 332 cache_key = self._make_cache_key(text) 333 if cache_key in self._memory_cache: 334 embeddings.append(self._memory_cache[cache_key]) 335 else: 336 uncached_texts.append(text) 337 uncached_indices.append(i) 338 embeddings.append(None) # Placeholder 339 340 logger.debug( 341 f"Cache hits: {len(texts) - len(uncached_texts)}/{len(texts)}" 342 ) 343 344 # Batch process uncached texts 345 if uncached_texts: 346 i = 0 347 while i < len(uncached_texts): 348 batch = uncached_texts[i : i + batch_size] 349 batch_indices = uncached_indices[i : i + batch_size] 350 351 last_error = None 352 for attempt in range(self._max_retries): 353 try: 354 response = await self.client.embeddings.create( 355 input=batch, model=self.deployment_name 356 ) 357 358 # Map back to original indices 359 for j, embedding_obj in enumerate(response.data): 360 original_index = batch_indices[j] 361 embedding = embedding_obj.embedding 362 embeddings[original_index] = embedding 363 364 # Cache update 365 cache_key = self._make_cache_key(batch[j]) 366 self._update_memory_cache(cache_key, embedding) 367 368 # Success - move to next batch 369 break 370 371 except Exception as e: 372 last_error = e 373 if "429" in str(e) or "rate" in str(e).lower(): 374 delay = self._calculate_backoff_delay(attempt) 375 logger.warning( 376 f"Rate limit hit in batch (attempt {attempt + 1}/{self._max_retries}), " 377 f"retrying after {delay:.1f}s..." 378 ) 379 await asyncio.sleep(delay) 380 continue 381 else: 382 logger.error(f"Batch embedding failed: {e}") 383 raise EmbeddingError(f"Batch embedding failed: {e}") from e 384 else: 385 # リトライ上限到達 386 raise RateLimitExceededError( 387 f"Rate limit exceeded after {self._max_retries} retries: {last_error}" 388 ) 389 390 i += batch_size 391 392 return embeddings # type: ignore 393 394 def _calculate_backoff_delay(self, attempt: int) -> float: 395 """ 396 指数バックオフ + ジッターでディレイを計算 397 398 Args: 399 attempt: 現在のリトライ回数 (0-indexed) 400 401 Returns: 402 待機時間(秒) 403 """ 404 # 指数バックオフ: base_delay * 2^attempt 405 delay = self._base_delay * (2 ** attempt) 406 # 上限適用 407 delay = min(delay, self._max_delay) 408 # ジッター追加 (±25%) 409 jitter = delay * 0.25 * (2 * random.random() - 1) 410 return delay + jitter 411 412 def _make_cache_key(self, text: str) -> str: 413 """ 414 Generate cache key (model_name + text_hash) 415 416 Args: 417 text: Input text 418 419 Returns: 420 Cache key string 421 """ 422 text_hash = hashlib.sha256(text.encode()).hexdigest()[:16] 423 return f"{self.model_name}:{text_hash}" 424 425 def _update_memory_cache(self, key: str, embedding: List[float]) -> None: 426 """ 427 Update LRU cache 428 429 Args: 430 key: Cache key 431 embedding: Embedding vector 432 """ 433 if key in self._memory_cache: 434 # Update order for existing key 435 self._cache_order.remove(key) 436 elif len(self._memory_cache) >= self._max_cache_size: 437 # Evict oldest entry 438 oldest_key = self._cache_order.pop(0) 439 del self._memory_cache[oldest_key] 440 logger.debug(f"Cache evicted - key={oldest_key[:16]}...") 441 442 self._memory_cache[key] = embedding 443 self._cache_order.append(key) 444 445 def get_cache_stats(self) -> Dict[str, Any]: 446 """ 447 キャッシュ統計情報を取得 448 449 キャッシュの使用状況を確認し、必要に応じてクリアするか判断できます。 450 451 Returns: 452 Dict[str, Any]: キャッシュ統計情報 453 - cache_size: 現在のキャッシュエントリ数 454 - max_cache_size: 最大キャッシュサイズ 455 - cache_utilization: キャッシュ使用率(0.0〜1.0) 456 - model_name: 使用中のモデル名 457 458 Example: 459 >>> stats = generator.get_cache_stats() 460 >>> print(f"Cache: {stats['cache_size']}/{stats['max_cache_size']}") 461 >>> print(f"Utilization: {stats['cache_utilization']:.1%}") 462 >>> if stats['cache_utilization'] > 0.9: 463 ... generator.clear_cache() 464 """ 465 return { 466 "cache_size": len(self._memory_cache), 467 "max_cache_size": self._max_cache_size, 468 "cache_utilization": len(self._memory_cache) / self._max_cache_size if self._max_cache_size > 0 else 0, 469 "model_name": self.model_name, 470 } 471 472 def clear_cache(self) -> None: 473 """ 474 キャッシュを全クリア 475 476 メモリ使用量が増えた場合や、新しいモデルに切り替える場合に使用します。 477 478 Example: 479 >>> # キャッシュ統計確認 480 >>> stats = generator.get_cache_stats() 481 >>> print(f"Before: {stats['cache_size']} entries") 482 >>> 483 >>> # キャッシュクリア 484 >>> generator.clear_cache() 485 >>> print(f"After: {generator.get_cache_stats()['cache_size']} entries") # 0 486 """ 487 self._memory_cache.clear() 488 self._cache_order.clear() 489 logger.info("Embedding cache cleared")
Azure OpenAI Embedding生成クラス
テキストをベクトル表現に変換します。 LRUキャッシュとレートリミット対応により、効率的かつ安定的にEmbeddingを生成します。
Features: - 非同期バッチEmbedding生成(batch_generate) - LRUメモリキャッシュによる重複リクエスト削減 - 指数バックオフ+ジッターによるレートリミット対応 - Azure OpenAI text-embedding-ada-002モデル対応 - キャッシュ統計とキャッシュクリア機能
Example:
from agenticstar_platform import EmbeddingGenerator, EmbeddingConfig
設定を作成
config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding")
クライアント初期化
generator = EmbeddingGenerator(config)
単一テキストのEmbedding生成
embedding = await generator.generate("Azure OpenAIの使い方を教えてください") print(f"Embedding dimensions: {len(embedding)}") # 1536
バッチEmbedding生成(効率的)
texts = ["テキスト1", "テキスト2", "テキスト3"] embeddings = await generator.batch_generate(texts, batch_size=16) print(f"Generated {len(embeddings)} embeddings")
キャッシュ統計確認
stats = generator.get_cache_stats() print(f"Cache utilization: {stats['cache_utilization']:.1%}")
188 def __init__(self, config: EmbeddingConfig): 189 """ 190 Initialize Embedding Generator 191 192 Args: 193 config: EmbeddingConfig instance 194 """ 195 self.client = AsyncAzureOpenAI( 196 azure_endpoint=config.base_url, 197 api_key=config.api_key, 198 api_version=config.api_version, 199 ) 200 self.deployment_name = config.model 201 self.model_name = config.model # For model_version tracking 202 self.dimensions = config.dimensions 203 204 # リトライ設定 205 self._max_retries = config.max_retries 206 self._base_delay = config.base_delay 207 self._max_delay = config.max_delay 208 209 # LRU Cache 210 self._memory_cache: Dict[str, List[float]] = {} 211 self._cache_order: List[str] = [] # LRU tracking 212 self._max_cache_size = config.max_cache_size 213 214 logger.info( 215 f"EmbeddingGenerator initialized - model={config.model}, " 216 f"cache_size={config.max_cache_size}, max_retries={config.max_retries}" 217 )
Initialize Embedding Generator
Args: config: EmbeddingConfig instance
219 async def generate(self, text: str) -> List[float]: 220 """ 221 単一テキストのEmbeddingを生成(キャッシュ対応) 222 223 同一テキストは自動的にキャッシュからEmbeddingを取得します。 224 レートリミットに達した場合は自動リトライを行います。 225 226 Args: 227 text: Embedding対象のテキスト 228 229 Returns: 230 List[float]: Embeddingベクトル(text-embedding-ada-002の場合は1536次元) 231 232 Raises: 233 RateLimitExceededError: レートリミット超過でリトライ上限到達時 234 EmbeddingError: その他のAPIエラー 235 236 Example: 237 >>> # 単一テキストのEmbedding生成 238 >>> embedding = await generator.generate("質問に回答する方法") 239 >>> print(f"Dimensions: {len(embedding)}") # 1536 240 >>> 241 >>> # キャッシュ済みの場合は即座に返る 242 >>> embedding_cached = await generator.generate("質問に回答する方法") 243 """ 244 # Cache key 245 cache_key = self._make_cache_key(text) 246 247 # Cache check 248 if cache_key in self._memory_cache: 249 logger.debug(f"Cache hit for key={cache_key[:16]}...") 250 return self._memory_cache[cache_key] 251 252 # API call with retry 253 last_error = None 254 for attempt in range(self._max_retries): 255 try: 256 logger.debug(f"Generating embedding - text_len={len(text)} chars, attempt={attempt + 1}") 257 response = await self.client.embeddings.create( 258 input=text, model=self.deployment_name 259 ) 260 embedding = response.data[0].embedding 261 262 # Cache update 263 self._update_memory_cache(cache_key, embedding) 264 265 logger.debug(f"Embedding generated - dims={len(embedding)}") 266 return embedding 267 268 except Exception as e: 269 last_error = e 270 if "429" in str(e) or "rate" in str(e).lower(): 271 delay = self._calculate_backoff_delay(attempt) 272 logger.warning( 273 f"Rate limit hit (attempt {attempt + 1}/{self._max_retries}), " 274 f"retrying after {delay:.1f}s..." 275 ) 276 await asyncio.sleep(delay) 277 continue 278 else: 279 logger.error(f"Embedding generation failed: {e}") 280 raise EmbeddingError(f"Embedding generation failed: {e}") from e 281 282 # リトライ上限到達 283 raise RateLimitExceededError( 284 f"Rate limit exceeded after {self._max_retries} retries: {last_error}" 285 )
単一テキストのEmbeddingを生成(キャッシュ対応)
同一テキストは自動的にキャッシュからEmbeddingを取得します。 レートリミットに達した場合は自動リトライを行います。
Args: text: Embedding対象のテキスト
Returns: List[float]: Embeddingベクトル(text-embedding-ada-002の場合は1536次元)
Raises: RateLimitExceededError: レートリミット超過でリトライ上限到達時 EmbeddingError: その他のAPIエラー
Example:
単一テキストのEmbedding生成
embedding = await generator.generate("質問に回答する方法") print(f"Dimensions: {len(embedding)}") # 1536
キャッシュ済みの場合は即座に返る
embedding_cached = await generator.generate("質問に回答する方法")
287 async def batch_generate( 288 self, texts: List[str], batch_size: int = 16 289 ) -> List[List[float]]: 290 """ 291 複数テキストのEmbeddingを一括生成(バッチ処理) 292 293 大量のテキストを効率的にEmbedding化します。 294 キャッシュ済みのテキストはスキップし、APIコール数を削減します。 295 296 Args: 297 texts: Embedding対象のテキストリスト 298 batch_size: APIバッチサイズ(推奨: 16-32)。大きすぎるとレートリミットに達しやすくなる 299 300 Returns: 301 List[List[float]]: Embeddingベクトルのリスト(入力と同じ順序) 302 303 Raises: 304 RateLimitExceededError: レートリミット超過でリトライ上限到達時 305 EmbeddingError: その他のAPIエラー 306 307 Example: 308 >>> # バッチEmbedding生成 309 >>> texts = [ 310 ... "ドキュメント1の内容", 311 ... "ドキュメント2の内容", 312 ... "ドキュメント3の内容" 313 ... ] 314 >>> embeddings = await generator.batch_generate(texts, batch_size=16) 315 >>> print(f"Generated {len(embeddings)} embeddings") 316 >>> 317 >>> # 各Embeddingを確認 318 >>> for i, emb in enumerate(embeddings): 319 ... print(f"Text {i}: {len(emb)} dimensions") 320 """ 321 logger.info( 322 f"Batch embedding generation - texts={len(texts)}, " 323 f"batch_size={batch_size}" 324 ) 325 326 # Check cache 327 embeddings: List[Optional[List[float]]] = [] 328 uncached_texts: List[str] = [] 329 uncached_indices: List[int] = [] 330 331 for i, text in enumerate(texts): 332 cache_key = self._make_cache_key(text) 333 if cache_key in self._memory_cache: 334 embeddings.append(self._memory_cache[cache_key]) 335 else: 336 uncached_texts.append(text) 337 uncached_indices.append(i) 338 embeddings.append(None) # Placeholder 339 340 logger.debug( 341 f"Cache hits: {len(texts) - len(uncached_texts)}/{len(texts)}" 342 ) 343 344 # Batch process uncached texts 345 if uncached_texts: 346 i = 0 347 while i < len(uncached_texts): 348 batch = uncached_texts[i : i + batch_size] 349 batch_indices = uncached_indices[i : i + batch_size] 350 351 last_error = None 352 for attempt in range(self._max_retries): 353 try: 354 response = await self.client.embeddings.create( 355 input=batch, model=self.deployment_name 356 ) 357 358 # Map back to original indices 359 for j, embedding_obj in enumerate(response.data): 360 original_index = batch_indices[j] 361 embedding = embedding_obj.embedding 362 embeddings[original_index] = embedding 363 364 # Cache update 365 cache_key = self._make_cache_key(batch[j]) 366 self._update_memory_cache(cache_key, embedding) 367 368 # Success - move to next batch 369 break 370 371 except Exception as e: 372 last_error = e 373 if "429" in str(e) or "rate" in str(e).lower(): 374 delay = self._calculate_backoff_delay(attempt) 375 logger.warning( 376 f"Rate limit hit in batch (attempt {attempt + 1}/{self._max_retries}), " 377 f"retrying after {delay:.1f}s..." 378 ) 379 await asyncio.sleep(delay) 380 continue 381 else: 382 logger.error(f"Batch embedding failed: {e}") 383 raise EmbeddingError(f"Batch embedding failed: {e}") from e 384 else: 385 # リトライ上限到達 386 raise RateLimitExceededError( 387 f"Rate limit exceeded after {self._max_retries} retries: {last_error}" 388 ) 389 390 i += batch_size 391 392 return embeddings # type: ignore
複数テキストのEmbeddingを一括生成(バッチ処理)
大量のテキストを効率的にEmbedding化します。 キャッシュ済みのテキストはスキップし、APIコール数を削減します。
Args: texts: Embedding対象のテキストリスト batch_size: APIバッチサイズ(推奨: 16-32)。大きすぎるとレートリミットに達しやすくなる
Returns: List[List[float]]: Embeddingベクトルのリスト(入力と同じ順序)
Raises: RateLimitExceededError: レートリミット超過でリトライ上限到達時 EmbeddingError: その他のAPIエラー
Example:
バッチEmbedding生成
texts = [ ... "ドキュメント1の内容", ... "ドキュメント2の内容", ... "ドキュメント3の内容" ... ] embeddings = await generator.batch_generate(texts, batch_size=16) print(f"Generated {len(embeddings)} embeddings")
各Embeddingを確認
for i, emb in enumerate(embeddings): ... print(f"Text {i}: {len(emb)} dimensions")
445 def get_cache_stats(self) -> Dict[str, Any]: 446 """ 447 キャッシュ統計情報を取得 448 449 キャッシュの使用状況を確認し、必要に応じてクリアするか判断できます。 450 451 Returns: 452 Dict[str, Any]: キャッシュ統計情報 453 - cache_size: 現在のキャッシュエントリ数 454 - max_cache_size: 最大キャッシュサイズ 455 - cache_utilization: キャッシュ使用率(0.0〜1.0) 456 - model_name: 使用中のモデル名 457 458 Example: 459 >>> stats = generator.get_cache_stats() 460 >>> print(f"Cache: {stats['cache_size']}/{stats['max_cache_size']}") 461 >>> print(f"Utilization: {stats['cache_utilization']:.1%}") 462 >>> if stats['cache_utilization'] > 0.9: 463 ... generator.clear_cache() 464 """ 465 return { 466 "cache_size": len(self._memory_cache), 467 "max_cache_size": self._max_cache_size, 468 "cache_utilization": len(self._memory_cache) / self._max_cache_size if self._max_cache_size > 0 else 0, 469 "model_name": self.model_name, 470 }
キャッシュ統計情報を取得
キャッシュの使用状況を確認し、必要に応じてクリアするか判断できます。
Returns: Dict[str, Any]: キャッシュ統計情報 - cache_size: 現在のキャッシュエントリ数 - max_cache_size: 最大キャッシュサイズ - cache_utilization: キャッシュ使用率(0.0〜1.0) - model_name: 使用中のモデル名
Example:
stats = generator.get_cache_stats() print(f"Cache: {stats['cache_size']}/{stats['max_cache_size']}") print(f"Utilization: {stats['cache_utilization']:.1%}") if stats['cache_utilization'] > 0.9: ... generator.clear_cache()
472 def clear_cache(self) -> None: 473 """ 474 キャッシュを全クリア 475 476 メモリ使用量が増えた場合や、新しいモデルに切り替える場合に使用します。 477 478 Example: 479 >>> # キャッシュ統計確認 480 >>> stats = generator.get_cache_stats() 481 >>> print(f"Before: {stats['cache_size']} entries") 482 >>> 483 >>> # キャッシュクリア 484 >>> generator.clear_cache() 485 >>> print(f"After: {generator.get_cache_stats()['cache_size']} entries") # 0 486 """ 487 self._memory_cache.clear() 488 self._cache_order.clear() 489 logger.info("Embedding cache cleared")
キャッシュを全クリア
メモリ使用量が増えた場合や、新しいモデルに切り替える場合に使用します。
Example:
キャッシュ統計確認
stats = generator.get_cache_stats() print(f"Before: {stats['cache_size']} entries")
キャッシュクリア
generator.clear_cache() print(f"After: {generator.get_cache_stats()['cache_size']} entries") # 0
38@dataclass 39class EmbeddingConfig: 40 """Embedding設定 41 42 Azure OpenAI Embedding APIへの接続設定を管理します。 43 44 Attributes: 45 base_url: Azure OpenAI エンドポイントURL(例: "https://your-resource.openai.azure.com/") 46 api_key: APIキー 47 model: デプロイメント名/モデル名(例: "text-embedding-3-small") 48 api_version: APIバージョン(デフォルト: "2024-02-15-preview") 49 dimensions: Embedding次元数(text-embedding-3-smallは1536) 50 max_cache_size: キャッシュ最大サイズ(デフォルト: 10000) 51 max_retries: リトライ回数(デフォルト: 5) 52 base_delay: 基本待機時間秒(デフォルト: 1.0) 53 max_delay: 最大待機時間秒(デフォルト: 60.0) 54 55 Example: 56 >>> # 標準的な作成方法 57 >>> config = EmbeddingConfig( 58 ... base_url="https://your-resource.openai.azure.com/", 59 ... api_key="your-api-key", 60 ... model="text-embedding-3-small", 61 ... dimensions=1536, 62 ... ) 63 64 >>> # 辞書から作成 65 >>> config = EmbeddingConfig.from_dict({ 66 ... "base_url": "https://your-resource.openai.azure.com/", 67 ... "api_key": "your-api-key", 68 ... "model": "text-embedding-3-small" 69 ... }) 70 71 >>> # TOMLファイルから作成 72 >>> config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding") 73 """ 74 # 標準パラメータ名(業界標準に準拠) 75 base_url: str 76 api_key: str 77 model: str = "text-embedding-ada-002" 78 api_version: str = "2024-02-15-preview" 79 max_cache_size: int = 10000 80 dimensions: int = 1536 # text-embedding-ada-002のデフォルト 81 # リトライ設定 82 max_retries: int = 5 83 base_delay: float = 1.0 84 max_delay: float = 60.0 85 86 @classmethod 87 def from_dict(cls, data: Dict[str, Any]) -> "EmbeddingConfig": 88 """辞書からEmbeddingConfigを作成 89 90 Args: 91 data: 設定辞書(base_url, api_key, model 等) 92 93 Returns: 94 EmbeddingConfig instance 95 96 Example: 97 >>> config = EmbeddingConfig.from_dict({ 98 ... "base_url": "https://your-resource.openai.azure.com/", 99 ... "api_key": "your-api-key", 100 ... "model": "text-embedding-3-small", 101 ... "dimensions": 1536 102 ... }) 103 """ 104 base_url = data.get("base_url") 105 if not base_url: 106 raise ValueError("'base_url' is required") 107 108 model = data.get("model", "text-embedding-ada-002") 109 110 return cls( 111 base_url=base_url, 112 api_key=data["api_key"], 113 model=model, 114 api_version=data.get("api_version", "2024-02-15-preview"), 115 max_cache_size=data.get("max_cache_size", 10000), 116 dimensions=data.get("dimensions", 1536), 117 max_retries=data.get("max_retries", 5), 118 base_delay=data.get("base_delay", 1.0), 119 max_delay=data.get("max_delay", 60.0), 120 ) 121 122 @classmethod 123 def from_toml(cls, toml_path: str, section: str = "rag.embedding") -> "EmbeddingConfig": 124 """TOMLファイルからEmbeddingConfigを作成 125 126 Args: 127 toml_path: TOMLファイルパス 128 section: セクション名(ドット区切りでネスト対応) 129 130 Returns: 131 EmbeddingConfig instance 132 133 Example: 134 >>> # config.toml の [rag.embedding] セクションを読み込み 135 >>> config = EmbeddingConfig.from_toml("config.toml") 136 """ 137 if not TOML_AVAILABLE: 138 raise ImportError("toml library is required: pip install toml") 139 140 with open(toml_path, "r") as f: 141 toml_data = toml.load(f) 142 143 # ネストされたセクションをドット区切りで辿る 144 data = toml_data 145 for key in section.split("."): 146 data = data[key] 147 148 return cls.from_dict(data)
Embedding設定
Azure OpenAI Embedding APIへの接続設定を管理します。
Attributes: base_url: Azure OpenAI エンドポイントURL(例: "https://your-resource.openai.azure.com/") api_key: APIキー model: デプロイメント名/モデル名(例: "text-embedding-3-small") api_version: APIバージョン(デフォルト: "2024-02-15-preview") dimensions: Embedding次元数(text-embedding-3-smallは1536) max_cache_size: キャッシュ最大サイズ(デフォルト: 10000) max_retries: リトライ回数(デフォルト: 5) base_delay: 基本待機時間秒(デフォルト: 1.0) max_delay: 最大待機時間秒(デフォルト: 60.0)
Example:
標準的な作成方法
config = EmbeddingConfig( ... base_url="https://your-resource.openai.azure.com/", ... api_key="your-api-key", ... model="text-embedding-3-small", ... dimensions=1536, ... )
>>> # 辞書から作成 >>> config = EmbeddingConfig.from_dict({ ... "base_url": "https://your-resource.openai.azure.com/", ... "api_key": "your-api-key", ... "model": "text-embedding-3-small" ... }) >>> # TOMLファイルから作成 >>> config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding")
86 @classmethod 87 def from_dict(cls, data: Dict[str, Any]) -> "EmbeddingConfig": 88 """辞書からEmbeddingConfigを作成 89 90 Args: 91 data: 設定辞書(base_url, api_key, model 等) 92 93 Returns: 94 EmbeddingConfig instance 95 96 Example: 97 >>> config = EmbeddingConfig.from_dict({ 98 ... "base_url": "https://your-resource.openai.azure.com/", 99 ... "api_key": "your-api-key", 100 ... "model": "text-embedding-3-small", 101 ... "dimensions": 1536 102 ... }) 103 """ 104 base_url = data.get("base_url") 105 if not base_url: 106 raise ValueError("'base_url' is required") 107 108 model = data.get("model", "text-embedding-ada-002") 109 110 return cls( 111 base_url=base_url, 112 api_key=data["api_key"], 113 model=model, 114 api_version=data.get("api_version", "2024-02-15-preview"), 115 max_cache_size=data.get("max_cache_size", 10000), 116 dimensions=data.get("dimensions", 1536), 117 max_retries=data.get("max_retries", 5), 118 base_delay=data.get("base_delay", 1.0), 119 max_delay=data.get("max_delay", 60.0), 120 )
辞書からEmbeddingConfigを作成
Args: data: 設定辞書(base_url, api_key, model 等)
Returns: EmbeddingConfig instance
Example:
config = EmbeddingConfig.from_dict({ ... "base_url": "https://your-resource.openai.azure.com/", ... "api_key": "your-api-key", ... "model": "text-embedding-3-small", ... "dimensions": 1536 ... })
122 @classmethod 123 def from_toml(cls, toml_path: str, section: str = "rag.embedding") -> "EmbeddingConfig": 124 """TOMLファイルからEmbeddingConfigを作成 125 126 Args: 127 toml_path: TOMLファイルパス 128 section: セクション名(ドット区切りでネスト対応) 129 130 Returns: 131 EmbeddingConfig instance 132 133 Example: 134 >>> # config.toml の [rag.embedding] セクションを読み込み 135 >>> config = EmbeddingConfig.from_toml("config.toml") 136 """ 137 if not TOML_AVAILABLE: 138 raise ImportError("toml library is required: pip install toml") 139 140 with open(toml_path, "r") as f: 141 toml_data = toml.load(f) 142 143 # ネストされたセクションをドット区切りで辿る 144 data = toml_data 145 for key in section.split("."): 146 data = data[key] 147 148 return cls.from_dict(data)
TOMLファイルからEmbeddingConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: EmbeddingConfig instance
Example:
config.toml の [rag.embedding] セクションを読み込み
config = EmbeddingConfig.from_toml("config.toml")
170class QdrantManager: 171 """ 172 Qdrant操作クラス 173 174 ベクトルデータベースQdrantへのアクセスを提供します。 175 テキストの埋め込み生成とベクトル検索を統合的に扱えます。 176 177 Features: 178 - Collection自動作成(HNSW最適化設定) 179 - Payload Index作成(検索フィルタ用) 180 - Upsert/Batch Upsert (冪等性保証 - 同一IDは上書き) 181 - Semantic Search (フィルタ対応) 182 - 統計情報取得 183 184 Example: 185 >>> from agenticstar_platform import QdrantManager, QdrantConfig, EmbeddingGenerator, EmbeddingConfig 186 >>> 187 >>> # 設定を作成 188 >>> embedding_config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding") 189 >>> qdrant_config = QdrantConfig.from_toml("config.toml", section="rag.qdrant") 190 >>> 191 >>> # Embeddingジェネレーター作成 192 >>> embedding_gen = EmbeddingGenerator(embedding_config) 193 >>> 194 >>> # コンテキストマネージャーで使用(推奨) 195 >>> async with QdrantManager(qdrant_config, embedding_gen) as manager: 196 ... # データ追加 197 ... await manager.upsert( 198 ... id="doc-001", 199 ... content="AIエージェントの設計パターン", 200 ... metadata={"category": "technical", "author": "John"} 201 ... ) 202 ... 203 ... # 検索 204 ... results = await manager.search( 205 ... query_text="エージェント設計", 206 ... limit=5, 207 ... filter_conditions={"category": "technical"} 208 ... ) 209 ... for hit in results["data"]["results"]: 210 ... print(f"Score: {hit['score']:.3f} - {hit['payload']}") 211 """ 212 213 def __init__( 214 self, 215 config: QdrantConfig, 216 embedding_generator: EmbeddingGenerator, 217 ): 218 """ 219 Initialize Qdrant Manager 220 221 Args: 222 config: QdrantConfig instance 223 embedding_generator: EmbeddingGenerator instance 224 225 Raises: 226 QdrantConfigError: vector_sizeとembedding dimensionsの不一致 227 """ 228 # vector_size検証 229 if config.vector_size != embedding_generator.dimensions: 230 raise QdrantConfigError( 231 f"Vector size mismatch: QdrantConfig.vector_size={config.vector_size} " 232 f"!= EmbeddingConfig.dimensions={embedding_generator.dimensions}. " 233 f"These values must match for proper vector storage." 234 ) 235 236 # auth_token_provider設定時はプロキシ経由なのでgRPCを無効化 237 self.client = QdrantClient( 238 url=config.url, 239 auth_token_provider=config.auth_token_provider, 240 prefer_grpc=config.auth_token_provider is None, 241 ) 242 self.config = config 243 self.collection_name = config.collection_name 244 self.embedding_generator = embedding_generator 245 self._initialized = False 246 247 logger.info( 248 f"QdrantManager initialized - url={config.url}, " 249 f"collection={config.collection_name}, vector_size={config.vector_size}" 250 ) 251 252 async def __aenter__(self) -> "QdrantManager": 253 """Context Manager: 開始時に初期化""" 254 await self.initialize() 255 return self 256 257 async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: 258 """Context Manager: 終了時にクローズ""" 259 await self.close() 260 261 def is_initialized(self) -> bool: 262 """初期化されているかどうか""" 263 return self._initialized 264 265 async def ensure_initialized(self) -> None: 266 """初期化されていなければ初期化""" 267 if not self._initialized: 268 await self.initialize() 269 270 async def initialize(self) -> None: 271 """ 272 Collection存在確認・作成・Index作成 273 """ 274 if self._initialized: 275 logger.debug("QdrantManager already initialized") 276 return 277 278 logger.info(f"Initializing QdrantManager - collection={self.collection_name}") 279 280 try: 281 # Collection existence check (sync call → asyncio.to_thread) 282 collections_response = await asyncio.to_thread( 283 self.client.get_collections 284 ) 285 collection_names = [col.name for col in collections_response.collections] 286 287 if self.collection_name not in collection_names: 288 logger.info(f"Creating collection: {self.collection_name}") 289 await self._create_collection() 290 await self._create_payload_indices() 291 else: 292 logger.info(f"Collection already exists: {self.collection_name}") 293 294 self._initialized = True 295 logger.info("QdrantManager initialized successfully") 296 297 except Exception as e: 298 logger.error(f"Failed to initialize QdrantManager: {e}") 299 raise 300 301 async def _create_collection(self) -> None: 302 """ 303 Create Qdrant collection with optimized settings 304 """ 305 # Distance mapping 306 distance_map = { 307 "cosine": Distance.COSINE, 308 "euclid": Distance.EUCLID, 309 "dot": Distance.DOT, 310 } 311 distance = distance_map.get(self.config.distance, Distance.COSINE) 312 313 await asyncio.to_thread( 314 self.client.create_collection, 315 collection_name=self.collection_name, 316 vectors_config=VectorParams( 317 size=self.config.vector_size, 318 distance=distance, 319 on_disk=self.config.on_disk_vectors, 320 ), 321 hnsw_config=HnswConfigDiff( 322 m=self.config.hnsw_m, 323 ef_construct=self.config.hnsw_ef_construct, 324 full_scan_threshold=10000, 325 on_disk=True, 326 ), 327 optimizers_config=OptimizersConfigDiff( 328 memmap_threshold=20000, 329 ), 330 on_disk_payload=self.config.on_disk_payload, 331 ) 332 logger.info(f"Collection created: {self.collection_name}") 333 334 async def _create_payload_indices(self) -> None: 335 """ 336 Create payload indices for filtering performance (configurable) 337 """ 338 for index_config in self.config.payload_indexes: 339 await asyncio.to_thread( 340 self.client.create_payload_index, 341 collection_name=self.collection_name, 342 field_name=index_config.field_name, 343 field_schema=index_config.field_schema, 344 ) 345 logger.info(f"Payload index created: {index_config.field_name}") 346 347 async def create_payload_index( 348 self, 349 field_name: str, 350 field_schema: str = "keyword" 351 ) -> Dict[str, Any]: 352 """ 353 カスタムPayload Indexを作成 354 355 Args: 356 field_name: フィールド名 357 field_schema: スキーマタイプ ("keyword", "integer", "float", "bool") 358 359 Returns: 360 Dict with success status 361 """ 362 try: 363 await asyncio.to_thread( 364 self.client.create_payload_index, 365 collection_name=self.collection_name, 366 field_name=field_name, 367 field_schema=field_schema, 368 ) 369 logger.info(f"Payload index created: {field_name}") 370 return {"success": True, "field_name": field_name} 371 except Exception as e: 372 logger.error(f"Failed to create payload index: {e}") 373 return {"success": False, "error": str(e), "error_code": "INDEX_CREATE_ERROR"} 374 375 async def upsert( 376 self, 377 id: Optional[str] = None, 378 content: Optional[str] = None, 379 metadata: Optional[Dict[str, Any]] = None, 380 ) -> Dict[str, Any]: 381 """ 382 ベクトルデータをUpsert(冪等性保証) 383 384 テキストからEmbeddingを生成し、ベクトルデータベースに保存します。 385 同一IDが既に存在する場合は上書きされます(冪等性保証)。 386 387 Args: 388 id: ドキュメントID(省略時は自動UUID生成) 389 content: 埋め込み対象のテキスト 390 metadata: メタデータ(検索時のフィルタリングや結果表示に使用) 391 392 Returns: 393 Dict[str, Any]: 実行結果 394 - success: bool 395 - data: {"id": str} (成功時) 396 - error: str, error_code: str (失敗時) 397 398 Example: 399 >>> result = await manager.upsert( 400 ... id="doc-001", # 省略可(自動UUID生成) 401 ... content="Pythonは汎用プログラミング言語です。", 402 ... metadata={"title": "Python入門", "category": "programming"} 403 ... ) 404 >>> if result["success"]: 405 ... print(f"Added: {result['data']['id']}") 406 407 >>> # ID自動生成 408 >>> result = await manager.upsert( 409 ... content="テキスト内容", 410 ... metadata={"source": "web"} 411 ... ) 412 >>> print(f"Auto-generated ID: {result['data']['id']}") 413 """ 414 await self.ensure_initialized() 415 416 import uuid as uuid_module 417 actual_id = id or str(uuid_module.uuid4()) 418 actual_metadata = metadata or {} 419 420 if not content: 421 return { 422 "success": False, 423 "error": "'content' is required", 424 "error_code": "MISSING_CONTENT", 425 } 426 427 try: 428 logger.debug(f"Upserting point - id={actual_id}") 429 430 # Generate embedding 431 embedding = await self.embedding_generator.generate(content) 432 433 # Add default metadata 434 full_payload = { 435 **actual_metadata, 436 "created_at": int(datetime.now().timestamp()), 437 "model_version": self.embedding_generator.model_name, 438 } 439 440 # Build Point 441 point = PointStruct( 442 id=actual_id, 443 vector=embedding, 444 payload=full_payload, 445 ) 446 447 # Upsert (idempotent - same ID overwrites, sync call → asyncio.to_thread) 448 await asyncio.to_thread( 449 self.client.upsert, 450 collection_name=self.collection_name, 451 points=[point], 452 ) 453 454 logger.info(f"Point upserted - id={actual_id}") 455 return {"success": True, "data": {"id": actual_id}} 456 457 except Exception as e: 458 logger.error(f"Failed to upsert point: {e}") 459 return { 460 "success": False, 461 "error": f"Upsert error: {str(e)}", 462 "error_code": "QDRANT_UPSERT_ERROR", 463 } 464 465 async def batch_upsert( 466 self, 467 items: List[Dict[str, Any]], 468 batch_size: int = 256, 469 ) -> Dict[str, Any]: 470 """ 471 バッチUpsert 472 473 Args: 474 items: List of {"id": str, "content": str, "metadata": dict} 475 batch_size: バッチサイズ 476 477 Returns: 478 Dict with success status and total_upserted 479 """ 480 await self.ensure_initialized() 481 482 try: 483 logger.info(f"Batch upserting {len(items)} points") 484 485 total_points = 0 486 487 for i in range(0, len(items), batch_size): 488 batch = items[i : i + batch_size] 489 490 # Extract texts and generate embeddings 491 texts = [item["content"] for item in batch] 492 embeddings = await self.embedding_generator.batch_generate(texts) 493 494 # Build Points 495 points = [] 496 for item, embedding in zip(batch, embeddings): 497 full_payload = { 498 **item.get("metadata", {}), 499 "created_at": int(datetime.now().timestamp()), 500 "model_version": self.embedding_generator.model_name, 501 } 502 points.append( 503 PointStruct( 504 id=item["id"], 505 vector=embedding, 506 payload=full_payload, 507 ) 508 ) 509 510 # Batch Upsert (sync call → asyncio.to_thread) 511 await asyncio.to_thread( 512 self.client.upsert, 513 collection_name=self.collection_name, 514 points=points, 515 ) 516 total_points += len(points) 517 518 logger.debug(f"Batch upserted {len(points)} points") 519 520 logger.info(f"Batch upsert completed - total={total_points}") 521 return {"success": True, "data": {"total_upserted": total_points}} 522 523 except Exception as e: 524 logger.error(f"Failed to batch upsert: {e}") 525 return { 526 "success": False, 527 "error": f"Batch upsert error: {str(e)}", 528 "error_code": "QDRANT_BATCH_UPSERT_ERROR", 529 } 530 531 async def search( 532 self, 533 query_text: str, 534 limit: int = 10, 535 score_threshold: float = 0.4, 536 filter_conditions: Optional[Dict[str, Any]] = None, 537 ef: Optional[int] = None, 538 ) -> Dict[str, Any]: 539 """ 540 セマンティック検索 541 542 クエリテキストに意味的に近いドキュメントを検索します。 543 フィルタ条件を指定して結果を絞り込むことも可能です。 544 545 Args: 546 query_text: 検索クエリ。自然言語で記述 547 limit: 結果数上限(デフォルト: 10) 548 score_threshold: スコア閾値 (Qdrant cosine: -1〜1、デフォルト: 0.4) 549 filter_conditions: フィルタ条件 {"field": value} 形式(オプション) 550 ef: HNSW検索深度 (デフォルト: limit * 2)。大きいほど精度が上がるが遅くなる 551 552 Returns: 553 Dict[str, Any]: 検索結果 554 - success: bool 555 - data: { 556 "query": str, 557 "results": [{"point_id": str, "payload": dict, "score": float, "similarity": float}], 558 "total_found": int, 559 "score_threshold": float 560 } 561 - error: str, error_code: str (失敗時) 562 563 Example: 564 >>> # 基本的な検索 565 >>> results = await manager.search( 566 ... query_text="プログラミング言語", 567 ... limit=5 568 ... ) 569 >>> for hit in results["data"]["results"]: 570 ... print(f"Score: {hit['score']:.3f} - {hit['payload']['title']}") 571 >>> 572 >>> # フィルタ付き検索 573 >>> results = await manager.search( 574 ... query_text="データ分析", 575 ... limit=10, 576 ... score_threshold=0.5, 577 ... filter_conditions={"category": "programming"} 578 ... ) 579 >>> print(f"Found {results['data']['total_found']} results") 580 """ 581 await self.ensure_initialized() 582 583 try: 584 logger.debug( 585 f"Searching - query_len={len(query_text)}, limit={limit}" 586 ) 587 588 # Generate query embedding 589 query_vector = await self.embedding_generator.generate(query_text) 590 591 # Build filter 592 query_filter = None 593 if filter_conditions: 594 must_conditions = [ 595 FieldCondition( 596 key=key, 597 match=MatchValue(value=value), 598 ) 599 for key, value in filter_conditions.items() 600 ] 601 query_filter = Filter(must=must_conditions) 602 603 # Search params 604 search_params = SearchParams( 605 hnsw_ef=ef or (limit * 2), 606 exact=False, 607 ) 608 609 # Execute search (sync call → asyncio.to_thread) 610 results = await asyncio.to_thread( 611 self.client.query_points, 612 collection_name=self.collection_name, 613 query=query_vector, 614 query_filter=query_filter, 615 limit=limit, 616 score_threshold=score_threshold, 617 search_params=search_params, 618 with_payload=True, 619 ) 620 621 # Format results 622 formatted_results = [ 623 { 624 "point_id": hit.id, 625 "payload": hit.payload, 626 "score": hit.score, 627 "similarity": self._score_to_similarity(hit.score), 628 } 629 for hit in results.points 630 ] 631 632 logger.info( 633 f"Search completed - found={len(formatted_results)}, " 634 f"query_len={len(query_text)}" 635 ) 636 return { 637 "success": True, 638 "data": { 639 "query": query_text, 640 "results": formatted_results, 641 "total_found": len(formatted_results), 642 "score_threshold": score_threshold, 643 }, 644 } 645 646 except Exception as e: 647 logger.error(f"Failed to search: {e}") 648 return { 649 "success": False, 650 "error": f"Search error: {str(e)}", 651 "error_code": "QDRANT_SEARCH_ERROR", 652 } 653 654 def _score_to_similarity(self, score: float) -> float: 655 """ 656 Qdrant cosine score (-1〜1) → 類似度 (0〜1) 変換 657 """ 658 return max(0.0, min(1.0, (score + 1.0) / 2.0)) 659 660 async def delete(self, point_ids: List[str]) -> Dict[str, Any]: 661 """ 662 ポイントを削除 663 664 指定したIDのベクトルデータを削除します。 665 666 Args: 667 point_ids: 削除するポイントIDのリスト 668 669 Returns: 670 Dict[str, Any]: 削除結果 671 - success: bool 672 - data: {"deleted_count": int} (成功時) 673 - error: str, error_code: str (失敗時) 674 675 Example: 676 >>> # 単一ポイントを削除 677 >>> result = await manager.delete(["doc-001"]) 678 >>> print(f"Deleted: {result['data']['deleted_count']}") 679 >>> 680 >>> # 複数ポイントを削除 681 >>> result = await manager.delete(["doc-001", "doc-002", "doc-003"]) 682 >>> if result["success"]: 683 ... print(f"Deleted {result['data']['deleted_count']} documents") 684 """ 685 await self.ensure_initialized() 686 687 try: 688 # Sync call → asyncio.to_thread 689 await asyncio.to_thread( 690 self.client.delete, 691 collection_name=self.collection_name, 692 points_selector=point_ids, 693 ) 694 logger.info(f"Deleted {len(point_ids)} points") 695 return {"success": True, "data": {"deleted_count": len(point_ids)}} 696 697 except Exception as e: 698 logger.error(f"Failed to delete points: {e}") 699 return { 700 "success": False, 701 "error": f"Delete error: {str(e)}", 702 "error_code": "QDRANT_DELETE_ERROR", 703 } 704 705 async def get_statistics(self) -> Dict[str, Any]: 706 """ 707 コレクション統計情報を取得 708 709 コレクション内のドキュメント数やコンテンツタイプの内訳などを取得します。 710 711 Returns: 712 Dict[str, Any]: 統計情報 713 - success: bool 714 - data: { 715 "total_objects": int, 716 "vectors_count": int, 717 "content_type_breakdown": dict, 718 "collection_name": str, 719 "status": str 720 } 721 - error: str, error_code: str (失敗時) 722 723 Example: 724 >>> stats = await manager.get_statistics() 725 >>> if stats["success"]: 726 ... data = stats["data"] 727 ... print(f"Total documents: {data['total_objects']}") 728 ... print(f"Status: {data['status']}") 729 ... print("Content types:") 730 ... for ctype, count in data["content_type_breakdown"].items(): 731 ... print(f" {ctype}: {count}") 732 """ 733 await self.ensure_initialized() 734 735 try: 736 # Sync calls → asyncio.to_thread 737 collection_info = await asyncio.to_thread( 738 self.client.get_collection, 739 self.collection_name, 740 ) 741 742 # Content type breakdown (sample 1000 points) 743 scroll_result = await asyncio.to_thread( 744 self.client.scroll, 745 collection_name=self.collection_name, 746 limit=1000, 747 with_payload=["content_type"], 748 with_vectors=False, 749 ) 750 751 content_type_breakdown = {} 752 for point in scroll_result[0]: 753 content_type = point.payload.get("content_type", "unknown") 754 content_type_breakdown[content_type] = ( 755 content_type_breakdown.get(content_type, 0) + 1 756 ) 757 758 stats = { 759 "total_objects": collection_info.points_count, 760 "vectors_count": collection_info.vectors_count, 761 "content_type_breakdown": content_type_breakdown, 762 "collection_name": self.collection_name, 763 "status": collection_info.status.value, 764 } 765 766 logger.info(f"Statistics retrieved - total={stats['total_objects']}") 767 return {"success": True, "data": stats} 768 769 except Exception as e: 770 logger.error(f"Failed to get statistics: {e}") 771 return { 772 "success": False, 773 "error": f"Statistics error: {str(e)}", 774 "error_code": "QDRANT_STATS_ERROR", 775 } 776 777 async def delete_collection(self, collection_name: Optional[str] = None) -> Dict[str, Any]: 778 """ 779 コレクションを削除 780 781 Args: 782 collection_name: 削除するコレクション名(省略時は現在のコレクション) 783 784 Returns: 785 Dict[str, Any]: 削除結果 786 - success: bool 787 - data: {"collection_name": str} (成功時) 788 - error: str, error_code: str (失敗時) 789 790 Example: 791 >>> # 現在のコレクションを削除 792 >>> result = await manager.delete_collection() 793 >>> if result["success"]: 794 ... print(f"Deleted: {result['data']['collection_name']}") 795 796 >>> # 別のコレクションを削除 797 >>> result = await manager.delete_collection("old_collection") 798 """ 799 target_collection = collection_name or self.collection_name 800 801 try: 802 await asyncio.to_thread( 803 self.client.delete_collection, 804 target_collection, 805 ) 806 logger.info(f"Collection deleted: {target_collection}") 807 808 # 現在のコレクションを削除した場合は初期化状態をリセット 809 if target_collection == self.collection_name: 810 self._initialized = False 811 812 return {"success": True, "data": {"collection_name": target_collection}} 813 814 except Exception as e: 815 logger.error(f"Failed to delete collection: {e}") 816 return { 817 "success": False, 818 "error": f"Delete collection error: {str(e)}", 819 "error_code": "QDRANT_DELETE_COLLECTION_ERROR", 820 } 821 822 async def collection_exists(self, collection_name: Optional[str] = None) -> bool: 823 """ 824 コレクションが存在するか確認 825 826 Args: 827 collection_name: 確認するコレクション名(省略時は現在のコレクション) 828 829 Returns: 830 bool: コレクションが存在する場合True 831 832 Example: 833 >>> if await manager.collection_exists(): 834 ... print("Collection exists") 835 >>> else: 836 ... print("Collection does not exist") 837 """ 838 target_collection = collection_name or self.collection_name 839 840 try: 841 collections_response = await asyncio.to_thread( 842 self.client.get_collections 843 ) 844 collection_names = [col.name for col in collections_response.collections] 845 return target_collection in collection_names 846 847 except Exception as e: 848 logger.error(f"Failed to check collection existence: {e}") 849 return False 850 851 async def close(self) -> None: 852 """ 853 Close Qdrant client 854 855 Note: 856 例外が発生してもリソースを確実にクリーンアップします。 857 """ 858 try: 859 logger.info("Closing QdrantManager") 860 # Qdrantクライアントのクローズ処理(必要に応じて) 861 if self.client and hasattr(self.client, 'close'): 862 self.client.close() 863 except Exception as e: 864 logger.warning(f"Error closing Qdrant client: {e}") 865 finally: 866 self.client = None 867 self._initialized = False 868 logger.debug("QdrantManager closed")
Qdrant操作クラス
ベクトルデータベースQdrantへのアクセスを提供します。 テキストの埋め込み生成とベクトル検索を統合的に扱えます。
Features: - Collection自動作成(HNSW最適化設定) - Payload Index作成(検索フィルタ用) - Upsert/Batch Upsert (冪等性保証 - 同一IDは上書き) - Semantic Search (フィルタ対応) - 統計情報取得
Example:
from agenticstar_platform import QdrantManager, QdrantConfig, EmbeddingGenerator, EmbeddingConfig
設定を作成
embedding_config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding") qdrant_config = QdrantConfig.from_toml("config.toml", section="rag.qdrant")
Embeddingジェネレーター作成
embedding_gen = EmbeddingGenerator(embedding_config)
コンテキストマネージャーで使用(推奨)
async with QdrantManager(qdrant_config, embedding_gen) as manager: ... # データ追加 ... await manager.upsert( ... id="doc-001", ... content="AIエージェントの設計パターン", ... metadata={"category": "technical", "author": "John"} ... ) ... ... # 検索 ... results = await manager.search( ... query_text="エージェント設計", ... limit=5, ... filter_conditions={"category": "technical"} ... ) ... for hit in results["data"]["results"]: ... print(f"Score: {hit['score']:.3f} - {hit['payload']}")
213 def __init__( 214 self, 215 config: QdrantConfig, 216 embedding_generator: EmbeddingGenerator, 217 ): 218 """ 219 Initialize Qdrant Manager 220 221 Args: 222 config: QdrantConfig instance 223 embedding_generator: EmbeddingGenerator instance 224 225 Raises: 226 QdrantConfigError: vector_sizeとembedding dimensionsの不一致 227 """ 228 # vector_size検証 229 if config.vector_size != embedding_generator.dimensions: 230 raise QdrantConfigError( 231 f"Vector size mismatch: QdrantConfig.vector_size={config.vector_size} " 232 f"!= EmbeddingConfig.dimensions={embedding_generator.dimensions}. " 233 f"These values must match for proper vector storage." 234 ) 235 236 # auth_token_provider設定時はプロキシ経由なのでgRPCを無効化 237 self.client = QdrantClient( 238 url=config.url, 239 auth_token_provider=config.auth_token_provider, 240 prefer_grpc=config.auth_token_provider is None, 241 ) 242 self.config = config 243 self.collection_name = config.collection_name 244 self.embedding_generator = embedding_generator 245 self._initialized = False 246 247 logger.info( 248 f"QdrantManager initialized - url={config.url}, " 249 f"collection={config.collection_name}, vector_size={config.vector_size}" 250 )
Initialize Qdrant Manager
Args: config: QdrantConfig instance embedding_generator: EmbeddingGenerator instance
Raises: QdrantConfigError: vector_sizeとembedding dimensionsの不一致
265 async def ensure_initialized(self) -> None: 266 """初期化されていなければ初期化""" 267 if not self._initialized: 268 await self.initialize()
初期化されていなければ初期化
270 async def initialize(self) -> None: 271 """ 272 Collection存在確認・作成・Index作成 273 """ 274 if self._initialized: 275 logger.debug("QdrantManager already initialized") 276 return 277 278 logger.info(f"Initializing QdrantManager - collection={self.collection_name}") 279 280 try: 281 # Collection existence check (sync call → asyncio.to_thread) 282 collections_response = await asyncio.to_thread( 283 self.client.get_collections 284 ) 285 collection_names = [col.name for col in collections_response.collections] 286 287 if self.collection_name not in collection_names: 288 logger.info(f"Creating collection: {self.collection_name}") 289 await self._create_collection() 290 await self._create_payload_indices() 291 else: 292 logger.info(f"Collection already exists: {self.collection_name}") 293 294 self._initialized = True 295 logger.info("QdrantManager initialized successfully") 296 297 except Exception as e: 298 logger.error(f"Failed to initialize QdrantManager: {e}") 299 raise
Collection存在確認・作成・Index作成
347 async def create_payload_index( 348 self, 349 field_name: str, 350 field_schema: str = "keyword" 351 ) -> Dict[str, Any]: 352 """ 353 カスタムPayload Indexを作成 354 355 Args: 356 field_name: フィールド名 357 field_schema: スキーマタイプ ("keyword", "integer", "float", "bool") 358 359 Returns: 360 Dict with success status 361 """ 362 try: 363 await asyncio.to_thread( 364 self.client.create_payload_index, 365 collection_name=self.collection_name, 366 field_name=field_name, 367 field_schema=field_schema, 368 ) 369 logger.info(f"Payload index created: {field_name}") 370 return {"success": True, "field_name": field_name} 371 except Exception as e: 372 logger.error(f"Failed to create payload index: {e}") 373 return {"success": False, "error": str(e), "error_code": "INDEX_CREATE_ERROR"}
カスタムPayload Indexを作成
Args: field_name: フィールド名 field_schema: スキーマタイプ ("keyword", "integer", "float", "bool")
Returns: Dict with success status
375 async def upsert( 376 self, 377 id: Optional[str] = None, 378 content: Optional[str] = None, 379 metadata: Optional[Dict[str, Any]] = None, 380 ) -> Dict[str, Any]: 381 """ 382 ベクトルデータをUpsert(冪等性保証) 383 384 テキストからEmbeddingを生成し、ベクトルデータベースに保存します。 385 同一IDが既に存在する場合は上書きされます(冪等性保証)。 386 387 Args: 388 id: ドキュメントID(省略時は自動UUID生成) 389 content: 埋め込み対象のテキスト 390 metadata: メタデータ(検索時のフィルタリングや結果表示に使用) 391 392 Returns: 393 Dict[str, Any]: 実行結果 394 - success: bool 395 - data: {"id": str} (成功時) 396 - error: str, error_code: str (失敗時) 397 398 Example: 399 >>> result = await manager.upsert( 400 ... id="doc-001", # 省略可(自動UUID生成) 401 ... content="Pythonは汎用プログラミング言語です。", 402 ... metadata={"title": "Python入門", "category": "programming"} 403 ... ) 404 >>> if result["success"]: 405 ... print(f"Added: {result['data']['id']}") 406 407 >>> # ID自動生成 408 >>> result = await manager.upsert( 409 ... content="テキスト内容", 410 ... metadata={"source": "web"} 411 ... ) 412 >>> print(f"Auto-generated ID: {result['data']['id']}") 413 """ 414 await self.ensure_initialized() 415 416 import uuid as uuid_module 417 actual_id = id or str(uuid_module.uuid4()) 418 actual_metadata = metadata or {} 419 420 if not content: 421 return { 422 "success": False, 423 "error": "'content' is required", 424 "error_code": "MISSING_CONTENT", 425 } 426 427 try: 428 logger.debug(f"Upserting point - id={actual_id}") 429 430 # Generate embedding 431 embedding = await self.embedding_generator.generate(content) 432 433 # Add default metadata 434 full_payload = { 435 **actual_metadata, 436 "created_at": int(datetime.now().timestamp()), 437 "model_version": self.embedding_generator.model_name, 438 } 439 440 # Build Point 441 point = PointStruct( 442 id=actual_id, 443 vector=embedding, 444 payload=full_payload, 445 ) 446 447 # Upsert (idempotent - same ID overwrites, sync call → asyncio.to_thread) 448 await asyncio.to_thread( 449 self.client.upsert, 450 collection_name=self.collection_name, 451 points=[point], 452 ) 453 454 logger.info(f"Point upserted - id={actual_id}") 455 return {"success": True, "data": {"id": actual_id}} 456 457 except Exception as e: 458 logger.error(f"Failed to upsert point: {e}") 459 return { 460 "success": False, 461 "error": f"Upsert error: {str(e)}", 462 "error_code": "QDRANT_UPSERT_ERROR", 463 }
ベクトルデータをUpsert(冪等性保証)
テキストからEmbeddingを生成し、ベクトルデータベースに保存します。 同一IDが既に存在する場合は上書きされます(冪等性保証)。
Args: id: ドキュメントID(省略時は自動UUID生成) content: 埋め込み対象のテキスト metadata: メタデータ(検索時のフィルタリングや結果表示に使用)
Returns: Dict[str, Any]: 実行結果 - success: bool - data: {"id": str} (成功時) - error: str, error_code: str (失敗時)
Example:
result = await manager.upsert( ... id="doc-001", # 省略可(自動UUID生成) ... content="Pythonは汎用プログラミング言語です。", ... metadata={"title": "Python入門", "category": "programming"} ... ) if result["success"]: ... print(f"Added: {result['data']['id']}")
>>> # ID自動生成 >>> result = await manager.upsert( ... content="テキスト内容", ... metadata={"source": "web"} ... ) >>> print(f"Auto-generated ID: {result['data']['id']}")
465 async def batch_upsert( 466 self, 467 items: List[Dict[str, Any]], 468 batch_size: int = 256, 469 ) -> Dict[str, Any]: 470 """ 471 バッチUpsert 472 473 Args: 474 items: List of {"id": str, "content": str, "metadata": dict} 475 batch_size: バッチサイズ 476 477 Returns: 478 Dict with success status and total_upserted 479 """ 480 await self.ensure_initialized() 481 482 try: 483 logger.info(f"Batch upserting {len(items)} points") 484 485 total_points = 0 486 487 for i in range(0, len(items), batch_size): 488 batch = items[i : i + batch_size] 489 490 # Extract texts and generate embeddings 491 texts = [item["content"] for item in batch] 492 embeddings = await self.embedding_generator.batch_generate(texts) 493 494 # Build Points 495 points = [] 496 for item, embedding in zip(batch, embeddings): 497 full_payload = { 498 **item.get("metadata", {}), 499 "created_at": int(datetime.now().timestamp()), 500 "model_version": self.embedding_generator.model_name, 501 } 502 points.append( 503 PointStruct( 504 id=item["id"], 505 vector=embedding, 506 payload=full_payload, 507 ) 508 ) 509 510 # Batch Upsert (sync call → asyncio.to_thread) 511 await asyncio.to_thread( 512 self.client.upsert, 513 collection_name=self.collection_name, 514 points=points, 515 ) 516 total_points += len(points) 517 518 logger.debug(f"Batch upserted {len(points)} points") 519 520 logger.info(f"Batch upsert completed - total={total_points}") 521 return {"success": True, "data": {"total_upserted": total_points}} 522 523 except Exception as e: 524 logger.error(f"Failed to batch upsert: {e}") 525 return { 526 "success": False, 527 "error": f"Batch upsert error: {str(e)}", 528 "error_code": "QDRANT_BATCH_UPSERT_ERROR", 529 }
バッチUpsert
Args: items: List of {"id": str, "content": str, "metadata": dict} batch_size: バッチサイズ
Returns: Dict with success status and total_upserted
531 async def search( 532 self, 533 query_text: str, 534 limit: int = 10, 535 score_threshold: float = 0.4, 536 filter_conditions: Optional[Dict[str, Any]] = None, 537 ef: Optional[int] = None, 538 ) -> Dict[str, Any]: 539 """ 540 セマンティック検索 541 542 クエリテキストに意味的に近いドキュメントを検索します。 543 フィルタ条件を指定して結果を絞り込むことも可能です。 544 545 Args: 546 query_text: 検索クエリ。自然言語で記述 547 limit: 結果数上限(デフォルト: 10) 548 score_threshold: スコア閾値 (Qdrant cosine: -1〜1、デフォルト: 0.4) 549 filter_conditions: フィルタ条件 {"field": value} 形式(オプション) 550 ef: HNSW検索深度 (デフォルト: limit * 2)。大きいほど精度が上がるが遅くなる 551 552 Returns: 553 Dict[str, Any]: 検索結果 554 - success: bool 555 - data: { 556 "query": str, 557 "results": [{"point_id": str, "payload": dict, "score": float, "similarity": float}], 558 "total_found": int, 559 "score_threshold": float 560 } 561 - error: str, error_code: str (失敗時) 562 563 Example: 564 >>> # 基本的な検索 565 >>> results = await manager.search( 566 ... query_text="プログラミング言語", 567 ... limit=5 568 ... ) 569 >>> for hit in results["data"]["results"]: 570 ... print(f"Score: {hit['score']:.3f} - {hit['payload']['title']}") 571 >>> 572 >>> # フィルタ付き検索 573 >>> results = await manager.search( 574 ... query_text="データ分析", 575 ... limit=10, 576 ... score_threshold=0.5, 577 ... filter_conditions={"category": "programming"} 578 ... ) 579 >>> print(f"Found {results['data']['total_found']} results") 580 """ 581 await self.ensure_initialized() 582 583 try: 584 logger.debug( 585 f"Searching - query_len={len(query_text)}, limit={limit}" 586 ) 587 588 # Generate query embedding 589 query_vector = await self.embedding_generator.generate(query_text) 590 591 # Build filter 592 query_filter = None 593 if filter_conditions: 594 must_conditions = [ 595 FieldCondition( 596 key=key, 597 match=MatchValue(value=value), 598 ) 599 for key, value in filter_conditions.items() 600 ] 601 query_filter = Filter(must=must_conditions) 602 603 # Search params 604 search_params = SearchParams( 605 hnsw_ef=ef or (limit * 2), 606 exact=False, 607 ) 608 609 # Execute search (sync call → asyncio.to_thread) 610 results = await asyncio.to_thread( 611 self.client.query_points, 612 collection_name=self.collection_name, 613 query=query_vector, 614 query_filter=query_filter, 615 limit=limit, 616 score_threshold=score_threshold, 617 search_params=search_params, 618 with_payload=True, 619 ) 620 621 # Format results 622 formatted_results = [ 623 { 624 "point_id": hit.id, 625 "payload": hit.payload, 626 "score": hit.score, 627 "similarity": self._score_to_similarity(hit.score), 628 } 629 for hit in results.points 630 ] 631 632 logger.info( 633 f"Search completed - found={len(formatted_results)}, " 634 f"query_len={len(query_text)}" 635 ) 636 return { 637 "success": True, 638 "data": { 639 "query": query_text, 640 "results": formatted_results, 641 "total_found": len(formatted_results), 642 "score_threshold": score_threshold, 643 }, 644 } 645 646 except Exception as e: 647 logger.error(f"Failed to search: {e}") 648 return { 649 "success": False, 650 "error": f"Search error: {str(e)}", 651 "error_code": "QDRANT_SEARCH_ERROR", 652 }
セマンティック検索
クエリテキストに意味的に近いドキュメントを検索します。 フィルタ条件を指定して結果を絞り込むことも可能です。
Args: query_text: 検索クエリ。自然言語で記述 limit: 結果数上限(デフォルト: 10) score_threshold: スコア閾値 (Qdrant cosine: -1〜1、デフォルト: 0.4) filter_conditions: フィルタ条件 {"field": value} 形式(オプション) ef: HNSW検索深度 (デフォルト: limit * 2)。大きいほど精度が上がるが遅くなる
Returns: Dict[str, Any]: 検索結果 - success: bool - data: { "query": str, "results": [{"point_id": str, "payload": dict, "score": float, "similarity": float}], "total_found": int, "score_threshold": float } - error: str, error_code: str (失敗時)
Example:
基本的な検索
results = await manager.search( ... query_text="プログラミング言語", ... limit=5 ... ) for hit in results["data"]["results"]: ... print(f"Score: {hit['score']:.3f} - {hit['payload']['title']}")
フィルタ付き検索
results = await manager.search( ... query_text="データ分析", ... limit=10, ... score_threshold=0.5, ... filter_conditions={"category": "programming"} ... ) print(f"Found {results['data']['total_found']} results")
660 async def delete(self, point_ids: List[str]) -> Dict[str, Any]: 661 """ 662 ポイントを削除 663 664 指定したIDのベクトルデータを削除します。 665 666 Args: 667 point_ids: 削除するポイントIDのリスト 668 669 Returns: 670 Dict[str, Any]: 削除結果 671 - success: bool 672 - data: {"deleted_count": int} (成功時) 673 - error: str, error_code: str (失敗時) 674 675 Example: 676 >>> # 単一ポイントを削除 677 >>> result = await manager.delete(["doc-001"]) 678 >>> print(f"Deleted: {result['data']['deleted_count']}") 679 >>> 680 >>> # 複数ポイントを削除 681 >>> result = await manager.delete(["doc-001", "doc-002", "doc-003"]) 682 >>> if result["success"]: 683 ... print(f"Deleted {result['data']['deleted_count']} documents") 684 """ 685 await self.ensure_initialized() 686 687 try: 688 # Sync call → asyncio.to_thread 689 await asyncio.to_thread( 690 self.client.delete, 691 collection_name=self.collection_name, 692 points_selector=point_ids, 693 ) 694 logger.info(f"Deleted {len(point_ids)} points") 695 return {"success": True, "data": {"deleted_count": len(point_ids)}} 696 697 except Exception as e: 698 logger.error(f"Failed to delete points: {e}") 699 return { 700 "success": False, 701 "error": f"Delete error: {str(e)}", 702 "error_code": "QDRANT_DELETE_ERROR", 703 }
ポイントを削除
指定したIDのベクトルデータを削除します。
Args: point_ids: 削除するポイントIDのリスト
Returns: Dict[str, Any]: 削除結果 - success: bool - data: {"deleted_count": int} (成功時) - error: str, error_code: str (失敗時)
Example:
単一ポイントを削除
result = await manager.delete(["doc-001"]) print(f"Deleted: {result['data']['deleted_count']}")
複数ポイントを削除
result = await manager.delete(["doc-001", "doc-002", "doc-003"]) if result["success"]: ... print(f"Deleted {result['data']['deleted_count']} documents")
705 async def get_statistics(self) -> Dict[str, Any]: 706 """ 707 コレクション統計情報を取得 708 709 コレクション内のドキュメント数やコンテンツタイプの内訳などを取得します。 710 711 Returns: 712 Dict[str, Any]: 統計情報 713 - success: bool 714 - data: { 715 "total_objects": int, 716 "vectors_count": int, 717 "content_type_breakdown": dict, 718 "collection_name": str, 719 "status": str 720 } 721 - error: str, error_code: str (失敗時) 722 723 Example: 724 >>> stats = await manager.get_statistics() 725 >>> if stats["success"]: 726 ... data = stats["data"] 727 ... print(f"Total documents: {data['total_objects']}") 728 ... print(f"Status: {data['status']}") 729 ... print("Content types:") 730 ... for ctype, count in data["content_type_breakdown"].items(): 731 ... print(f" {ctype}: {count}") 732 """ 733 await self.ensure_initialized() 734 735 try: 736 # Sync calls → asyncio.to_thread 737 collection_info = await asyncio.to_thread( 738 self.client.get_collection, 739 self.collection_name, 740 ) 741 742 # Content type breakdown (sample 1000 points) 743 scroll_result = await asyncio.to_thread( 744 self.client.scroll, 745 collection_name=self.collection_name, 746 limit=1000, 747 with_payload=["content_type"], 748 with_vectors=False, 749 ) 750 751 content_type_breakdown = {} 752 for point in scroll_result[0]: 753 content_type = point.payload.get("content_type", "unknown") 754 content_type_breakdown[content_type] = ( 755 content_type_breakdown.get(content_type, 0) + 1 756 ) 757 758 stats = { 759 "total_objects": collection_info.points_count, 760 "vectors_count": collection_info.vectors_count, 761 "content_type_breakdown": content_type_breakdown, 762 "collection_name": self.collection_name, 763 "status": collection_info.status.value, 764 } 765 766 logger.info(f"Statistics retrieved - total={stats['total_objects']}") 767 return {"success": True, "data": stats} 768 769 except Exception as e: 770 logger.error(f"Failed to get statistics: {e}") 771 return { 772 "success": False, 773 "error": f"Statistics error: {str(e)}", 774 "error_code": "QDRANT_STATS_ERROR", 775 }
コレクション統計情報を取得
コレクション内のドキュメント数やコンテンツタイプの内訳などを取得します。
Returns: Dict[str, Any]: 統計情報 - success: bool - data: { "total_objects": int, "vectors_count": int, "content_type_breakdown": dict, "collection_name": str, "status": str } - error: str, error_code: str (失敗時)
Example:
stats = await manager.get_statistics() if stats["success"]: ... data = stats["data"] ... print(f"Total documents: {data['total_objects']}") ... print(f"Status: {data['status']}") ... print("Content types:") ... for ctype, count in data["content_type_breakdown"].items(): ... print(f" {ctype}: {count}")
777 async def delete_collection(self, collection_name: Optional[str] = None) -> Dict[str, Any]: 778 """ 779 コレクションを削除 780 781 Args: 782 collection_name: 削除するコレクション名(省略時は現在のコレクション) 783 784 Returns: 785 Dict[str, Any]: 削除結果 786 - success: bool 787 - data: {"collection_name": str} (成功時) 788 - error: str, error_code: str (失敗時) 789 790 Example: 791 >>> # 現在のコレクションを削除 792 >>> result = await manager.delete_collection() 793 >>> if result["success"]: 794 ... print(f"Deleted: {result['data']['collection_name']}") 795 796 >>> # 別のコレクションを削除 797 >>> result = await manager.delete_collection("old_collection") 798 """ 799 target_collection = collection_name or self.collection_name 800 801 try: 802 await asyncio.to_thread( 803 self.client.delete_collection, 804 target_collection, 805 ) 806 logger.info(f"Collection deleted: {target_collection}") 807 808 # 現在のコレクションを削除した場合は初期化状態をリセット 809 if target_collection == self.collection_name: 810 self._initialized = False 811 812 return {"success": True, "data": {"collection_name": target_collection}} 813 814 except Exception as e: 815 logger.error(f"Failed to delete collection: {e}") 816 return { 817 "success": False, 818 "error": f"Delete collection error: {str(e)}", 819 "error_code": "QDRANT_DELETE_COLLECTION_ERROR", 820 }
コレクションを削除
Args: collection_name: 削除するコレクション名(省略時は現在のコレクション)
Returns: Dict[str, Any]: 削除結果 - success: bool - data: {"collection_name": str} (成功時) - error: str, error_code: str (失敗時)
Example:
現在のコレクションを削除
result = await manager.delete_collection() if result["success"]: ... print(f"Deleted: {result['data']['collection_name']}")
>>> # 別のコレクションを削除 >>> result = await manager.delete_collection("old_collection")
822 async def collection_exists(self, collection_name: Optional[str] = None) -> bool: 823 """ 824 コレクションが存在するか確認 825 826 Args: 827 collection_name: 確認するコレクション名(省略時は現在のコレクション) 828 829 Returns: 830 bool: コレクションが存在する場合True 831 832 Example: 833 >>> if await manager.collection_exists(): 834 ... print("Collection exists") 835 >>> else: 836 ... print("Collection does not exist") 837 """ 838 target_collection = collection_name or self.collection_name 839 840 try: 841 collections_response = await asyncio.to_thread( 842 self.client.get_collections 843 ) 844 collection_names = [col.name for col in collections_response.collections] 845 return target_collection in collection_names 846 847 except Exception as e: 848 logger.error(f"Failed to check collection existence: {e}") 849 return False
コレクションが存在するか確認
Args: collection_name: 確認するコレクション名(省略時は現在のコレクション)
Returns: bool: コレクションが存在する場合True
Example:
if await manager.collection_exists(): ... print("Collection exists") else: ... print("Collection does not exist")
851 async def close(self) -> None: 852 """ 853 Close Qdrant client 854 855 Note: 856 例外が発生してもリソースを確実にクリーンアップします。 857 """ 858 try: 859 logger.info("Closing QdrantManager") 860 # Qdrantクライアントのクローズ処理(必要に応じて) 861 if self.client and hasattr(self.client, 'close'): 862 self.client.close() 863 except Exception as e: 864 logger.warning(f"Error closing Qdrant client: {e}") 865 finally: 866 self.client = None 867 self._initialized = False 868 logger.debug("QdrantManager closed")
Close Qdrant client
Note: 例外が発生してもリソースを確実にクリーンアップします。
64@dataclass 65class QdrantConfig: 66 """Qdrant設定 67 68 Example: 69 >>> # 辞書から作成 70 >>> config = QdrantConfig.from_dict({ 71 ... "url": "http://localhost:6333", 72 ... "collection_name": "my_collection", 73 ... "vector_size": 1536 74 ... }) 75 76 >>> # TOMLファイルから作成 77 >>> config = QdrantConfig.from_toml("config.toml", section="rag.qdrant") 78 """ 79 url: str 80 collection_name: str 81 vector_size: int = 1536 # text-embedding-ada-002 82 distance: str = "cosine" # cosine, euclid, dot 83 on_disk_vectors: bool = False 84 on_disk_payload: bool = True 85 hnsw_m: int = 16 86 hnsw_ef_construct: int = 256 87 # Payload Indexes(カスタマイズ可能) 88 payload_indexes: List[PayloadIndexConfig] = field(default_factory=lambda: [ 89 PayloadIndexConfig(field_name="content_type", field_schema="keyword"), 90 PayloadIndexConfig(field_name="created_at", field_schema="integer"), 91 ]) 92 # CLIモード用認証トークン取得関数(QdrantClientのauth_token_providerに渡す) 93 # 設定時はプロキシ経由のためgRPCを無効化 94 auth_token_provider: Optional[Callable[[], str]] = None 95 96 @classmethod 97 def from_dict(cls, data: Dict[str, Any]) -> "QdrantConfig": 98 """辞書からQdrantConfigを作成 99 100 Args: 101 data: 設定辞書 102 103 Returns: 104 QdrantConfig instance 105 106 Example: 107 >>> config = QdrantConfig.from_dict({ 108 ... "url": "http://localhost:6333", 109 ... "collection_name": "my_collection", 110 ... "vector_size": 1536, 111 ... "distance": "cosine" 112 ... }) 113 """ 114 # payload_indexesがある場合は変換 115 payload_indexes = None 116 if "payload_indexes" in data: 117 payload_indexes = [ 118 PayloadIndexConfig.from_dict(idx) if isinstance(idx, dict) else idx 119 for idx in data["payload_indexes"] 120 ] 121 122 return cls( 123 url=data["url"], 124 collection_name=data["collection_name"], 125 vector_size=data.get("vector_size", 1536), 126 distance=data.get("distance", "cosine"), 127 on_disk_vectors=data.get("on_disk_vectors", False), 128 on_disk_payload=data.get("on_disk_payload", True), 129 hnsw_m=data.get("hnsw_m", 16), 130 hnsw_ef_construct=data.get("hnsw_ef_construct", 256), 131 payload_indexes=payload_indexes if payload_indexes else [ 132 PayloadIndexConfig(field_name="content_type", field_schema="keyword"), 133 PayloadIndexConfig(field_name="created_at", field_schema="integer"), 134 ], 135 # auth_token_providerは辞書から設定不可(コールバック関数のため) 136 ) 137 138 @classmethod 139 def from_toml(cls, toml_path: str, section: str = "rag.qdrant") -> "QdrantConfig": 140 """TOMLファイルからQdrantConfigを作成 141 142 Args: 143 toml_path: TOMLファイルパス 144 section: セクション名(ドット区切りでネスト対応) 145 146 Returns: 147 QdrantConfig instance 148 149 Example: 150 >>> # config.toml の [rag.qdrant] セクションを読み込み 151 >>> config = QdrantConfig.from_toml("config.toml") 152 153 >>> # カスタムセクション 154 >>> config = QdrantConfig.from_toml("config.toml", section="vectordb") 155 """ 156 if not TOML_AVAILABLE: 157 raise ImportError("toml library is required: pip install toml") 158 159 with open(toml_path, "r") as f: 160 toml_data = toml.load(f) 161 162 # ネストされたセクションをドット区切りで辿る 163 data = toml_data 164 for key in section.split("."): 165 data = data[key] 166 167 return cls.from_dict(data)
Qdrant設定
Example:
辞書から作成
config = QdrantConfig.from_dict({ ... "url": "http://localhost:6333", ... "collection_name": "my_collection", ... "vector_size": 1536 ... })
>>> # TOMLファイルから作成 >>> config = QdrantConfig.from_toml("config.toml", section="rag.qdrant")
96 @classmethod 97 def from_dict(cls, data: Dict[str, Any]) -> "QdrantConfig": 98 """辞書からQdrantConfigを作成 99 100 Args: 101 data: 設定辞書 102 103 Returns: 104 QdrantConfig instance 105 106 Example: 107 >>> config = QdrantConfig.from_dict({ 108 ... "url": "http://localhost:6333", 109 ... "collection_name": "my_collection", 110 ... "vector_size": 1536, 111 ... "distance": "cosine" 112 ... }) 113 """ 114 # payload_indexesがある場合は変換 115 payload_indexes = None 116 if "payload_indexes" in data: 117 payload_indexes = [ 118 PayloadIndexConfig.from_dict(idx) if isinstance(idx, dict) else idx 119 for idx in data["payload_indexes"] 120 ] 121 122 return cls( 123 url=data["url"], 124 collection_name=data["collection_name"], 125 vector_size=data.get("vector_size", 1536), 126 distance=data.get("distance", "cosine"), 127 on_disk_vectors=data.get("on_disk_vectors", False), 128 on_disk_payload=data.get("on_disk_payload", True), 129 hnsw_m=data.get("hnsw_m", 16), 130 hnsw_ef_construct=data.get("hnsw_ef_construct", 256), 131 payload_indexes=payload_indexes if payload_indexes else [ 132 PayloadIndexConfig(field_name="content_type", field_schema="keyword"), 133 PayloadIndexConfig(field_name="created_at", field_schema="integer"), 134 ], 135 # auth_token_providerは辞書から設定不可(コールバック関数のため) 136 )
辞書からQdrantConfigを作成
Args: data: 設定辞書
Returns: QdrantConfig instance
Example:
config = QdrantConfig.from_dict({ ... "url": "http://localhost:6333", ... "collection_name": "my_collection", ... "vector_size": 1536, ... "distance": "cosine" ... })
138 @classmethod 139 def from_toml(cls, toml_path: str, section: str = "rag.qdrant") -> "QdrantConfig": 140 """TOMLファイルからQdrantConfigを作成 141 142 Args: 143 toml_path: TOMLファイルパス 144 section: セクション名(ドット区切りでネスト対応) 145 146 Returns: 147 QdrantConfig instance 148 149 Example: 150 >>> # config.toml の [rag.qdrant] セクションを読み込み 151 >>> config = QdrantConfig.from_toml("config.toml") 152 153 >>> # カスタムセクション 154 >>> config = QdrantConfig.from_toml("config.toml", section="vectordb") 155 """ 156 if not TOML_AVAILABLE: 157 raise ImportError("toml library is required: pip install toml") 158 159 with open(toml_path, "r") as f: 160 toml_data = toml.load(f) 161 162 # ネストされたセクションをドット区切りで辿る 163 data = toml_data 164 for key in section.split("."): 165 data = data[key] 166 167 return cls.from_dict(data)
TOMLファイルからQdrantConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: QdrantConfig instance
Example:
config.toml の [rag.qdrant] セクションを読み込み
config = QdrantConfig.from_toml("config.toml")
>>> # カスタムセクション >>> config = QdrantConfig.from_toml("config.toml", section="vectordb")
12class EventType(str, Enum): 13 """ 14 AgentLoopPublisher粒度のイベントタイプ定義 15 16 フロントエンドでのUI表示制御に使用される主要イベントタイプ 17 """ 18 19 # フェーズ系イベント 20 PHASE_START = "phase_start" 21 PROGRESS_UPDATE = "progress_update" 22 23 # AI思考プロセス 24 THOUGHT_MESSAGE = "thought_message" 25 26 # 完了系イベント 27 COMPLETION_SUCCESS = "completion_success" 28 COMPLETION_FAILURE = "completion_failure" 29 30 # 相互作用イベント 31 USER_INTERACTION_REQUIRED = "user_interaction_required" 32 UNEXPECTED_ERROR = "unexpected_error" 33 34 # HITL(Human-In-The-Loop)関連イベント 35 HITL_REQUIRED_BROWSER_VNC = "hitl_required_browser_vnc" # Pod環境: VNC経由でブラウザ介入 36 HITL_REQUIRED_BROWSER_CLI = "hitl_required_browser_cli" # CLI環境: ローカルブラウザ介入 37 HITL_COMPLETED = "hitl_completed" # HITL操作完了 38 39 # ファイル作成イベント 40 FILE_CREATED = "file_created" 41 42 # パーミッション関連イベント 43 PERMISSION_REQUEST = "permission_request" 44 PERMISSION_RESPONSE = "permission_response" 45 46 # 統一ツールイベント 47 TOOL_START = "tool_start" # ツール実行開始 48 TOOL_RESULT = "tool_result" # ツール実行結果 49 50 # 最終結果 51 FINAL_RESULT = "final_result"
AgentLoopPublisher粒度のイベントタイプ定義
フロントエンドでのUI表示制御に使用される主要イベントタイプ
54class SubEventType(str, Enum): 55 """ 56 ツール結果のサブイベントタイプ定義 57 58 フロントエンドでの表示出し分けに使用 59 例: search_webはWeb検索アイコン、file_editedはファイル編集アイコン 60 """ 61 62 # Web検索関連 63 SEARCH_WEB = "search_web" 64 65 # コマンド実行関連 66 COMMAND_EXECUTION = "command_execution" 67 68 # ファイル操作関連 69 FILE_OPERATION = "file_operation" 70 71 # Claude Code関連 72 CLAUDE_CODE = "claude_code" 73 74 # MCPツール関連 75 MCP_TOOL = "mcp_tool" 76 77 # OpenCode互換ツールイベント 78 FILE_EDITED = "file_edited" # write, edit, multiedit 79 FILE_READ = "file_read" # read 80 FILE_SEARCHED = "file_searched" # grep, glob, ls 81 BASH_EXECUTED = "bash_executed" # bash 82 WEB_FETCHED = "web_fetched" # webfetch 83 TASK_LAUNCHED = "task_launched" # task 84 TODO_UPDATED = "todo_updated" # todo 85 VIDEO_GENERATED = "video_generated" # video_gen 86 87 # Slide generation 88 SLIDE_CREATED = "slide_created" # save_html (slide preview)
ツール結果のサブイベントタイプ定義
フロントエンドでの表示出し分けに使用 例: search_webはWeb検索アイコン、file_editedはファイル編集アイコン
21@dataclass 22class StreamingEvent: 23 """ 24 ストリーミングイベントデータ 25 26 リアルタイムストリーミング(SSE等)で使用される基本イベント構造 27 28 Attributes: 29 event_type: イベントタイプ(EventType enum) 30 execution_id: 実行ID 31 message: イベントメッセージ 32 timestamp: イベント発生時刻(Unix timestamp、省略時は現在時刻) 33 metadata: 追加メタデータ(オプション) 34 sub_event_type: サブイベントタイプ(フロント表示用、オプション) 35 """ 36 37 event_type: EventType 38 execution_id: str 39 message: str 40 timestamp: float = field(default_factory=_current_timestamp) 41 metadata: Optional[Dict[str, Any]] = None 42 sub_event_type: Optional[SubEventType] = None 43 44 def to_dict(self) -> Dict[str, Any]: 45 """辞書形式に変換""" 46 result = { 47 "event_type": self.event_type.value, 48 "execution_id": self.execution_id, 49 "message": self.message, 50 "timestamp": self.timestamp, 51 } 52 if self.metadata is not None: 53 result["metadata"] = self.metadata 54 if self.sub_event_type is not None: 55 result["sub_event_type"] = self.sub_event_type.value 56 return result
ストリーミングイベントデータ
リアルタイムストリーミング(SSE等)で使用される基本イベント構造
Attributes: event_type: イベントタイプ(EventType enum) execution_id: 実行ID message: イベントメッセージ timestamp: イベント発生時刻(Unix timestamp、省略時は現在時刻) metadata: 追加メタデータ(オプション) sub_event_type: サブイベントタイプ(フロント表示用、オプション)
44 def to_dict(self) -> Dict[str, Any]: 45 """辞書形式に変換""" 46 result = { 47 "event_type": self.event_type.value, 48 "execution_id": self.execution_id, 49 "message": self.message, 50 "timestamp": self.timestamp, 51 } 52 if self.metadata is not None: 53 result["metadata"] = self.metadata 54 if self.sub_event_type is not None: 55 result["sub_event_type"] = self.sub_event_type.value 56 return result
辞書形式に変換
59@dataclass 60class SequencedEvent: 61 """ 62 順序付きイベント 63 64 イベントの順序保証が必要な場合に使用 65 66 Attributes: 67 sequence: シーケンス番号(0から開始、負の値は不可) 68 event_type: イベントタイプ(EventType enum) 69 execution_id: 実行ID 70 message: イベントメッセージ 71 timestamp: イベント発生時刻(Unix timestamp、省略時は現在時刻) 72 metadata: 追加メタデータ(オプション) 73 sub_event_type: サブイベントタイプ(フロント表示用、オプション) 74 """ 75 76 sequence: int 77 event_type: EventType 78 execution_id: str 79 message: str 80 timestamp: float = field(default_factory=_current_timestamp) 81 metadata: Optional[Dict[str, Any]] = None 82 sub_event_type: Optional[SubEventType] = None 83 84 def __post_init__(self): 85 """シーケンス番号のバリデーション""" 86 if self.sequence < 0: 87 raise ValueError(f"sequence must be >= 0, got {self.sequence}") 88 89 def to_streaming_event(self) -> StreamingEvent: 90 """StreamingEventに変換""" 91 return StreamingEvent( 92 event_type=self.event_type, 93 execution_id=self.execution_id, 94 message=self.message, 95 timestamp=self.timestamp, 96 metadata=self.metadata, 97 sub_event_type=self.sub_event_type, 98 ) 99 100 def to_dict(self) -> Dict[str, Any]: 101 """辞書形式に変換""" 102 result = { 103 "sequence": self.sequence, 104 "event_type": self.event_type.value, 105 "execution_id": self.execution_id, 106 "message": self.message, 107 "timestamp": self.timestamp, 108 } 109 if self.metadata is not None: 110 result["metadata"] = self.metadata 111 if self.sub_event_type is not None: 112 result["sub_event_type"] = self.sub_event_type.value 113 return result
順序付きイベント
イベントの順序保証が必要な場合に使用
Attributes: sequence: シーケンス番号(0から開始、負の値は不可) event_type: イベントタイプ(EventType enum) execution_id: 実行ID message: イベントメッセージ timestamp: イベント発生時刻(Unix timestamp、省略時は現在時刻) metadata: 追加メタデータ(オプション) sub_event_type: サブイベントタイプ(フロント表示用、オプション)
89 def to_streaming_event(self) -> StreamingEvent: 90 """StreamingEventに変換""" 91 return StreamingEvent( 92 event_type=self.event_type, 93 execution_id=self.execution_id, 94 message=self.message, 95 timestamp=self.timestamp, 96 metadata=self.metadata, 97 sub_event_type=self.sub_event_type, 98 )
StreamingEventに変換
100 def to_dict(self) -> Dict[str, Any]: 101 """辞書形式に変換""" 102 result = { 103 "sequence": self.sequence, 104 "event_type": self.event_type.value, 105 "execution_id": self.execution_id, 106 "message": self.message, 107 "timestamp": self.timestamp, 108 } 109 if self.metadata is not None: 110 result["metadata"] = self.metadata 111 if self.sub_event_type is not None: 112 result["sub_event_type"] = self.sub_event_type.value 113 return result
辞書形式に変換
116@dataclass 117class ExecutionMessage: 118 """ 119 実行メッセージデータ(DB保存用) 120 121 エージェント実行中のイベントをデータベースに保存するための構造体。 122 マーケットプレイスUIとの連携に使用。 123 124 Attributes: 125 user_id: ユーザーID 126 conversation_id: 会話ID 127 message_id: メッセージID 128 chunk_type: チャンクタイプ(content, reasoning_content, task_created等) 129 content_data: コンテンツデータ(JSON) 130 131 Example: 132 >>> msg = ExecutionMessage( 133 ... user_id="user-123", 134 ... conversation_id="conv-456", 135 ... message_id="msg-789", 136 ... chunk_type="content", 137 ... content_data={"content": "処理中です..."} 138 ... ) 139 >>> await data_access.insert("execution_messages", msg.to_dict()) 140 """ 141 142 user_id: str 143 conversation_id: str 144 message_id: str 145 chunk_type: str 146 content_data: Dict[str, Any] 147 148 def to_dict(self) -> Dict[str, Any]: 149 """DB挿入用の辞書形式に変換""" 150 return { 151 "user_id": self.user_id, 152 "conversation_id": self.conversation_id, 153 "message_id": self.message_id, 154 "chunk_type": self.chunk_type, 155 "content_data": self.content_data, 156 } 157 158 @classmethod 159 def from_streaming_event( 160 cls, 161 event: "StreamingEvent", 162 user_id: str, 163 conversation_id: str, 164 message_id: str, 165 chunk_type: Optional[str] = None, 166 ) -> "ExecutionMessage": 167 """StreamingEventからExecutionMessageを作成 168 169 Args: 170 event: ストリーミングイベント 171 user_id: ユーザーID 172 conversation_id: 会話ID 173 message_id: メッセージID 174 chunk_type: チャンクタイプ(省略時はevent_typeから推定) 175 176 Returns: 177 ExecutionMessage instance 178 179 Example: 180 >>> msg = ExecutionMessage.from_streaming_event( 181 ... event=event, 182 ... user_id="user-123", 183 ... conversation_id="conv-456", 184 ... message_id="msg-789" 185 ... ) 186 """ 187 # チャンクタイプの推定(省略時) 188 if chunk_type is None: 189 chunk_type_map = { 190 EventType.PHASE_START: "content", 191 EventType.PROGRESS_UPDATE: "content", 192 EventType.THOUGHT_MESSAGE: "reasoning_content", 193 EventType.TOOL_START: "task_created", 194 EventType.TOOL_RESULT: "task_created", 195 EventType.FILE_CREATED: "task_created", 196 EventType.COMPLETION_SUCCESS: "content", 197 EventType.COMPLETION_FAILURE: "content", 198 } 199 chunk_type = chunk_type_map.get(event.event_type, "content") 200 201 content_data = { 202 "content": event.message, 203 } 204 if event.metadata: 205 content_data.update(event.metadata) 206 207 return cls( 208 user_id=user_id, 209 conversation_id=conversation_id, 210 message_id=message_id, 211 chunk_type=chunk_type, 212 content_data=content_data, 213 )
実行メッセージデータ(DB保存用)
エージェント実行中のイベントをデータベースに保存するための構造体。 マーケットプレイスUIとの連携に使用。
Attributes: user_id: ユーザーID conversation_id: 会話ID message_id: メッセージID chunk_type: チャンクタイプ(content, reasoning_content, task_created等) content_data: コンテンツデータ(JSON)
Example:
msg = ExecutionMessage( ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789", ... chunk_type="content", ... content_data={"content": "処理中です..."} ... ) await data_access.insert("execution_messages", msg.to_dict())
148 def to_dict(self) -> Dict[str, Any]: 149 """DB挿入用の辞書形式に変換""" 150 return { 151 "user_id": self.user_id, 152 "conversation_id": self.conversation_id, 153 "message_id": self.message_id, 154 "chunk_type": self.chunk_type, 155 "content_data": self.content_data, 156 }
DB挿入用の辞書形式に変換
158 @classmethod 159 def from_streaming_event( 160 cls, 161 event: "StreamingEvent", 162 user_id: str, 163 conversation_id: str, 164 message_id: str, 165 chunk_type: Optional[str] = None, 166 ) -> "ExecutionMessage": 167 """StreamingEventからExecutionMessageを作成 168 169 Args: 170 event: ストリーミングイベント 171 user_id: ユーザーID 172 conversation_id: 会話ID 173 message_id: メッセージID 174 chunk_type: チャンクタイプ(省略時はevent_typeから推定) 175 176 Returns: 177 ExecutionMessage instance 178 179 Example: 180 >>> msg = ExecutionMessage.from_streaming_event( 181 ... event=event, 182 ... user_id="user-123", 183 ... conversation_id="conv-456", 184 ... message_id="msg-789" 185 ... ) 186 """ 187 # チャンクタイプの推定(省略時) 188 if chunk_type is None: 189 chunk_type_map = { 190 EventType.PHASE_START: "content", 191 EventType.PROGRESS_UPDATE: "content", 192 EventType.THOUGHT_MESSAGE: "reasoning_content", 193 EventType.TOOL_START: "task_created", 194 EventType.TOOL_RESULT: "task_created", 195 EventType.FILE_CREATED: "task_created", 196 EventType.COMPLETION_SUCCESS: "content", 197 EventType.COMPLETION_FAILURE: "content", 198 } 199 chunk_type = chunk_type_map.get(event.event_type, "content") 200 201 content_data = { 202 "content": event.message, 203 } 204 if event.metadata: 205 content_data.update(event.metadata) 206 207 return cls( 208 user_id=user_id, 209 conversation_id=conversation_id, 210 message_id=message_id, 211 chunk_type=chunk_type, 212 content_data=content_data, 213 )
StreamingEventからExecutionMessageを作成
Args: event: ストリーミングイベント user_id: ユーザーID conversation_id: 会話ID message_id: メッセージID chunk_type: チャンクタイプ(省略時はevent_typeから推定)
Returns: ExecutionMessage instance
Example:
msg = ExecutionMessage.from_streaming_event( ... event=event, ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789" ... )
41class EventEmitter: 42 """ 43 イベント発行クライアント 44 45 非同期イベントキューを使用してイベントを発行・消費します。 46 シーケンス番号による順序保証を提供します。 47 48 Features: 49 - 非同期イベント発行(emit_event) 50 - シーケンス番号による順序保証 51 - 非同期イベント消費(consume_events) 52 - カスタムイベントハンドラー対応 53 54 Example: 55 >>> # 基本的な使い方 56 >>> emitter = EventEmitter(execution_id="exec-123") 57 >>> 58 >>> # イベント発行 59 >>> await emitter.emit_event( 60 ... EventType.PHASE_START, 61 ... "処理を開始します" 62 ... ) 63 >>> 64 >>> # イベント消費 65 >>> async for chunk in emitter.consume_events(): 66 ... print(chunk) 67 >>> 68 >>> # 完了通知 69 >>> await emitter.emit_event( 70 ... EventType.COMPLETION_SUCCESS, 71 ... "処理が完了しました" 72 ... ) 73 74 Note: 75 - emit_eventとconsume_eventsは非同期で実行してください 76 - 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)発行後、 77 consume_eventsは自動的に終了します 78 """ 79 80 def __init__( 81 self, 82 execution_id: str, 83 handler: Optional[EventHandler] = None, 84 ): 85 """ 86 EventEmitterを初期化 87 88 Args: 89 execution_id: 実行ID(イベントの識別に使用) 90 handler: イベントハンドラー(オプション) 91 指定しない場合、デフォルトのJSON変換が使用されます 92 93 Example: 94 >>> emitter = EventEmitter(execution_id="exec-123") 95 96 >>> # カスタムハンドラー付き 97 >>> async def custom_handler(event: StreamingEvent) -> Optional[str]: 98 ... return f"data: {event.to_dict()}\\n\\n" 99 >>> emitter = EventEmitter(execution_id="exec-123", handler=custom_handler) 100 """ 101 self.execution_id = execution_id 102 self._handler = handler 103 self._event_queue: asyncio.Queue[SequencedEvent] = asyncio.Queue() 104 self._sequence_number = 0 105 self._is_completed = False 106 self._completion_event = asyncio.Event() 107 108 logger.debug(f"[{execution_id}] EventEmitter initialized") 109 110 @property 111 def is_completed(self) -> bool: 112 """完了状態を取得""" 113 return self._is_completed 114 115 @property 116 def sequence_number(self) -> int: 117 """現在のシーケンス番号を取得""" 118 return self._sequence_number 119 120 async def emit_event( 121 self, 122 event_type: EventType, 123 message: str, 124 metadata: Optional[Dict[str, Any]] = None, 125 sub_event_type: Optional[SubEventType] = None, 126 ) -> None: 127 """ 128 イベントを発行 129 130 Args: 131 event_type: イベントタイプ 132 message: イベントメッセージ 133 metadata: 追加メタデータ(オプション) 134 sub_event_type: サブイベントタイプ(オプション) 135 136 Example: 137 >>> await emitter.emit_event( 138 ... EventType.THOUGHT_MESSAGE, 139 ... "処理中です...", 140 ... metadata={"progress": 50} 141 ... ) 142 143 Note: 144 完了イベント発行後にイベントを発行しようとすると、 145 警告ログが出力され、イベントは無視されます。 146 """ 147 if self._is_completed: 148 logger.warning( 149 f"[{self.execution_id}] Event emission after completion: {event_type}" 150 ) 151 return 152 153 sequenced_event = SequencedEvent( 154 sequence=self._sequence_number, 155 event_type=event_type, 156 execution_id=self.execution_id, 157 message=message, 158 timestamp=time.time(), 159 metadata=metadata, 160 sub_event_type=sub_event_type, 161 ) 162 self._sequence_number += 1 163 164 try: 165 await self._event_queue.put(sequenced_event) 166 logger.info( 167 f"[{self.execution_id}] Event queued: {event_type.value} " 168 f"(seq={sequenced_event.sequence})" 169 ) 170 except Exception as e: 171 logger.error( 172 f"[{self.execution_id}] Failed to queue event {event_type}: {e}" 173 ) 174 raise 175 176 async def emit( 177 self, 178 event_type: EventType, 179 message: str, 180 metadata: Optional[Dict[str, Any]] = None, 181 sub_event_type: Optional[SubEventType] = None, 182 ) -> None: 183 """emit_eventのエイリアス(短縮形)""" 184 await self.emit_event(event_type, message, metadata, sub_event_type) 185 186 async def consume_events( 187 self, 188 timeout: float = 0.1, 189 ) -> AsyncGenerator[str, None]: 190 """ 191 イベントを順序通りに消費してストリーミングチャンクを生成 192 193 Args: 194 timeout: イベント待機タイムアウト(秒) 195 196 Yields: 197 str: ストリーミングチャンク(ハンドラーの戻り値) 198 199 Example: 200 >>> async for chunk in emitter.consume_events(): 201 ... await response.write(chunk) 202 203 Note: 204 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)を受信し、 205 キューが空になった時点で自動的に終了します。 206 """ 207 logger.info(f"[{self.execution_id}] Starting event consumption") 208 209 while True: 210 try: 211 event = await asyncio.wait_for( 212 self._event_queue.get(), 213 timeout=timeout, 214 ) 215 216 logger.debug( 217 f"[{self.execution_id}] Processing event: {event.event_type.value} " 218 f"(seq={event.sequence})" 219 ) 220 221 # イベントをストリーミングチャンクに変換 222 chunk = await self._process_event(event) 223 if chunk: 224 yield chunk 225 226 # 完了イベントを処理したら完了状態に設定 227 if event.event_type in [ 228 EventType.COMPLETION_SUCCESS, 229 EventType.COMPLETION_FAILURE, 230 ]: 231 self._completion_event.set() 232 self._is_completed = True 233 logger.info( 234 f"[{self.execution_id}] Completion event processed" 235 ) 236 237 except asyncio.TimeoutError: 238 # 完了済みでキューが空なら終了 239 if self._completion_event.is_set() and self._event_queue.empty(): 240 logger.info(f"[{self.execution_id}] No more events, terminating") 241 break 242 continue 243 except Exception as e: 244 logger.error(f"[{self.execution_id}] Error consuming events: {e}") 245 break 246 247 logger.info(f"[{self.execution_id}] Event consumption completed") 248 249 async def _process_event(self, event: SequencedEvent) -> Optional[str]: 250 """イベントをストリーミング用に処理""" 251 streaming_event = event.to_streaming_event() 252 253 if self._handler: 254 try: 255 return await self._handler(streaming_event) 256 except Exception as e: 257 logger.error( 258 f"[{self.execution_id}] Error in event handler: {e}" 259 ) 260 return None 261 else: 262 # デフォルト: SSE形式で返す 263 import json 264 return f"data: {json.dumps(streaming_event.to_dict())}\n\n" 265 266 def mark_completed(self) -> None: 267 """ 268 手動で完了状態に設定 269 270 Note: 271 通常はCOMPLETION_SUCCESS/COMPLETION_FAILUREイベントで 272 自動的に完了状態になります。 273 このメソッドは特殊なケースでのみ使用してください。 274 """ 275 self._is_completed = True 276 self._completion_event.set() 277 logger.info(f"[{self.execution_id}] Manually marked as completed") 278 279 async def cleanup(self) -> None: 280 """ 281 リソースのクリーンアップ 282 283 キューに残っているイベントを破棄し、状態をリセットします。 284 285 Example: 286 >>> await emitter.cleanup() 287 """ 288 # キューをクリア 289 while not self._event_queue.empty(): 290 try: 291 self._event_queue.get_nowait() 292 except asyncio.QueueEmpty: 293 break 294 295 self._is_completed = True 296 self._completion_event.set() 297 logger.info(f"[{self.execution_id}] EventEmitter cleaned up")
イベント発行クライアント
非同期イベントキューを使用してイベントを発行・消費します。 シーケンス番号による順序保証を提供します。
Features: - 非同期イベント発行(emit_event) - シーケンス番号による順序保証 - 非同期イベント消費(consume_events) - カスタムイベントハンドラー対応
Example:
基本的な使い方
emitter = EventEmitter(execution_id="exec-123")
イベント発行
await emitter.emit_event( ... EventType.PHASE_START, ... "処理を開始します" ... )
イベント消費
async for chunk in emitter.consume_events(): ... print(chunk)
完了通知
await emitter.emit_event( ... EventType.COMPLETION_SUCCESS, ... "処理が完了しました" ... )
Note: - emit_eventとconsume_eventsは非同期で実行してください - 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)発行後、 consume_eventsは自動的に終了します
80 def __init__( 81 self, 82 execution_id: str, 83 handler: Optional[EventHandler] = None, 84 ): 85 """ 86 EventEmitterを初期化 87 88 Args: 89 execution_id: 実行ID(イベントの識別に使用) 90 handler: イベントハンドラー(オプション) 91 指定しない場合、デフォルトのJSON変換が使用されます 92 93 Example: 94 >>> emitter = EventEmitter(execution_id="exec-123") 95 96 >>> # カスタムハンドラー付き 97 >>> async def custom_handler(event: StreamingEvent) -> Optional[str]: 98 ... return f"data: {event.to_dict()}\\n\\n" 99 >>> emitter = EventEmitter(execution_id="exec-123", handler=custom_handler) 100 """ 101 self.execution_id = execution_id 102 self._handler = handler 103 self._event_queue: asyncio.Queue[SequencedEvent] = asyncio.Queue() 104 self._sequence_number = 0 105 self._is_completed = False 106 self._completion_event = asyncio.Event() 107 108 logger.debug(f"[{execution_id}] EventEmitter initialized")
EventEmitterを初期化
Args: execution_id: 実行ID(イベントの識別に使用) handler: イベントハンドラー(オプション) 指定しない場合、デフォルトのJSON変換が使用されます
Example:
emitter = EventEmitter(execution_id="exec-123")
>>> # カスタムハンドラー付き >>> async def custom_handler(event: StreamingEvent) -> Optional[str]: ... return f"data: {event.to_dict()}\n\n" >>> emitter = EventEmitter(execution_id="exec-123", handler=custom_handler)
115 @property 116 def sequence_number(self) -> int: 117 """現在のシーケンス番号を取得""" 118 return self._sequence_number
現在のシーケンス番号を取得
120 async def emit_event( 121 self, 122 event_type: EventType, 123 message: str, 124 metadata: Optional[Dict[str, Any]] = None, 125 sub_event_type: Optional[SubEventType] = None, 126 ) -> None: 127 """ 128 イベントを発行 129 130 Args: 131 event_type: イベントタイプ 132 message: イベントメッセージ 133 metadata: 追加メタデータ(オプション) 134 sub_event_type: サブイベントタイプ(オプション) 135 136 Example: 137 >>> await emitter.emit_event( 138 ... EventType.THOUGHT_MESSAGE, 139 ... "処理中です...", 140 ... metadata={"progress": 50} 141 ... ) 142 143 Note: 144 完了イベント発行後にイベントを発行しようとすると、 145 警告ログが出力され、イベントは無視されます。 146 """ 147 if self._is_completed: 148 logger.warning( 149 f"[{self.execution_id}] Event emission after completion: {event_type}" 150 ) 151 return 152 153 sequenced_event = SequencedEvent( 154 sequence=self._sequence_number, 155 event_type=event_type, 156 execution_id=self.execution_id, 157 message=message, 158 timestamp=time.time(), 159 metadata=metadata, 160 sub_event_type=sub_event_type, 161 ) 162 self._sequence_number += 1 163 164 try: 165 await self._event_queue.put(sequenced_event) 166 logger.info( 167 f"[{self.execution_id}] Event queued: {event_type.value} " 168 f"(seq={sequenced_event.sequence})" 169 ) 170 except Exception as e: 171 logger.error( 172 f"[{self.execution_id}] Failed to queue event {event_type}: {e}" 173 ) 174 raise
イベントを発行
Args: event_type: イベントタイプ message: イベントメッセージ metadata: 追加メタデータ(オプション) sub_event_type: サブイベントタイプ(オプション)
Example:
await emitter.emit_event( ... EventType.THOUGHT_MESSAGE, ... "処理中です...", ... metadata={"progress": 50} ... )
Note: 完了イベント発行後にイベントを発行しようとすると、 警告ログが出力され、イベントは無視されます。
176 async def emit( 177 self, 178 event_type: EventType, 179 message: str, 180 metadata: Optional[Dict[str, Any]] = None, 181 sub_event_type: Optional[SubEventType] = None, 182 ) -> None: 183 """emit_eventのエイリアス(短縮形)""" 184 await self.emit_event(event_type, message, metadata, sub_event_type)
emit_eventのエイリアス(短縮形)
186 async def consume_events( 187 self, 188 timeout: float = 0.1, 189 ) -> AsyncGenerator[str, None]: 190 """ 191 イベントを順序通りに消費してストリーミングチャンクを生成 192 193 Args: 194 timeout: イベント待機タイムアウト(秒) 195 196 Yields: 197 str: ストリーミングチャンク(ハンドラーの戻り値) 198 199 Example: 200 >>> async for chunk in emitter.consume_events(): 201 ... await response.write(chunk) 202 203 Note: 204 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)を受信し、 205 キューが空になった時点で自動的に終了します。 206 """ 207 logger.info(f"[{self.execution_id}] Starting event consumption") 208 209 while True: 210 try: 211 event = await asyncio.wait_for( 212 self._event_queue.get(), 213 timeout=timeout, 214 ) 215 216 logger.debug( 217 f"[{self.execution_id}] Processing event: {event.event_type.value} " 218 f"(seq={event.sequence})" 219 ) 220 221 # イベントをストリーミングチャンクに変換 222 chunk = await self._process_event(event) 223 if chunk: 224 yield chunk 225 226 # 完了イベントを処理したら完了状態に設定 227 if event.event_type in [ 228 EventType.COMPLETION_SUCCESS, 229 EventType.COMPLETION_FAILURE, 230 ]: 231 self._completion_event.set() 232 self._is_completed = True 233 logger.info( 234 f"[{self.execution_id}] Completion event processed" 235 ) 236 237 except asyncio.TimeoutError: 238 # 完了済みでキューが空なら終了 239 if self._completion_event.is_set() and self._event_queue.empty(): 240 logger.info(f"[{self.execution_id}] No more events, terminating") 241 break 242 continue 243 except Exception as e: 244 logger.error(f"[{self.execution_id}] Error consuming events: {e}") 245 break 246 247 logger.info(f"[{self.execution_id}] Event consumption completed")
イベントを順序通りに消費してストリーミングチャンクを生成
Args: timeout: イベント待機タイムアウト(秒)
Yields: str: ストリーミングチャンク(ハンドラーの戻り値)
Example:
async for chunk in emitter.consume_events(): ... await response.write(chunk)
Note: 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)を受信し、 キューが空になった時点で自動的に終了します。
266 def mark_completed(self) -> None: 267 """ 268 手動で完了状態に設定 269 270 Note: 271 通常はCOMPLETION_SUCCESS/COMPLETION_FAILUREイベントで 272 自動的に完了状態になります。 273 このメソッドは特殊なケースでのみ使用してください。 274 """ 275 self._is_completed = True 276 self._completion_event.set() 277 logger.info(f"[{self.execution_id}] Manually marked as completed")
手動で完了状態に設定
Note: 通常はCOMPLETION_SUCCESS/COMPLETION_FAILUREイベントで 自動的に完了状態になります。 このメソッドは特殊なケースでのみ使用してください。
279 async def cleanup(self) -> None: 280 """ 281 リソースのクリーンアップ 282 283 キューに残っているイベントを破棄し、状態をリセットします。 284 285 Example: 286 >>> await emitter.cleanup() 287 """ 288 # キューをクリア 289 while not self._event_queue.empty(): 290 try: 291 self._event_queue.get_nowait() 292 except asyncio.QueueEmpty: 293 break 294 295 self._is_completed = True 296 self._completion_event.set() 297 logger.info(f"[{self.execution_id}] EventEmitter cleaned up")
リソースのクリーンアップ
キューに残っているイベントを破棄し、状態をリセットします。
Example:
await emitter.cleanup()
25class EventHandler(Protocol): 26 """イベントハンドラープロトコル 27 28 イベント受信時のコールバック関数の型定義 29 30 Example: 31 >>> async def my_handler(event: StreamingEvent) -> Optional[str]: 32 ... print(f"Received: {event.message}") 33 ... return f"data: {event.to_dict()}" 34 """ 35 36 async def __call__(self, event: StreamingEvent) -> Optional[str]: 37 """イベントを処理してストリーミングチャンクを返す""" 38 ...
イベントハンドラープロトコル
イベント受信時のコールバック関数の型定義
Example:
async def my_handler(event: StreamingEvent) -> Optional[str]: ... print(f"Received: {event.message}") ... return f"data: {event.to_dict()}"
1739def _no_init_or_replace_init(self, *args, **kwargs): 1740 cls = type(self) 1741 1742 if cls._is_protocol: 1743 raise TypeError('Protocols cannot be instantiated') 1744 1745 # Already using a custom `__init__`. No need to calculate correct 1746 # `__init__` to call. This can lead to RecursionError. See bpo-45121. 1747 if cls.__init__ is not _no_init_or_replace_init: 1748 return 1749 1750 # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`. 1751 # The first instantiation of the subclass will call `_no_init_or_replace_init` which 1752 # searches for a proper new `__init__` in the MRO. The new `__init__` 1753 # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent 1754 # instantiation of the protocol subclass will thus use the new 1755 # `__init__` and no longer call `_no_init_or_replace_init`. 1756 for base in cls.__mro__: 1757 init = base.__dict__.get('__init__', _no_init_or_replace_init) 1758 if init is not _no_init_or_replace_init: 1759 cls.__init__ = init 1760 break 1761 else: 1762 # should not happen 1763 cls.__init__ = object.__init__ 1764 1765 cls.__init__(self, *args, **kwargs)
300def create_sse_handler() -> EventHandler: 301 """ 302 SSE(Server-Sent Events)形式のイベントハンドラーを作成 303 304 Returns: 305 EventHandler: SSE形式でイベントを変換するハンドラー 306 307 Example: 308 >>> emitter = EventEmitter( 309 ... execution_id="exec-123", 310 ... handler=create_sse_handler() 311 ... ) 312 """ 313 import json 314 315 async def sse_handler(event: StreamingEvent) -> Optional[str]: 316 return f"data: {json.dumps(event.to_dict())}\n\n" 317 318 return sse_handler
SSE(Server-Sent Events)形式のイベントハンドラーを作成
Returns: EventHandler: SSE形式でイベントを変換するハンドラー
Example:
emitter = EventEmitter( ... execution_id="exec-123", ... handler=create_sse_handler() ... )
321def create_json_handler() -> EventHandler: 322 """ 323 JSON形式のイベントハンドラーを作成 324 325 Returns: 326 EventHandler: JSON形式でイベントを変換するハンドラー 327 328 Example: 329 >>> emitter = EventEmitter( 330 ... execution_id="exec-123", 331 ... handler=create_json_handler() 332 ... ) 333 """ 334 import json 335 336 async def json_handler(event: StreamingEvent) -> Optional[str]: 337 return json.dumps(event.to_dict()) 338 339 return json_handler
JSON形式のイベントハンドラーを作成
Returns: EventHandler: JSON形式でイベントを変換するハンドラー
Example:
emitter = EventEmitter( ... execution_id="exec-123", ... handler=create_json_handler() ... )
45class DatabaseEventHandler: 46 """ 47 データベースにイベントを保存するハンドラー 48 49 SDKのDataAccessを使用してexecution_messagesテーブルにイベントを保存。 50 マーケットプレイスUIとの連携に使用。 51 52 Features: 53 - StreamingEventをExecutionMessageに変換 54 - 汎用DataAccessによるDB保存 55 - テーブル名のカスタマイズ対応 56 57 Example: 58 >>> from agenticstar_platform.db import DataAccess, PostgreSQLConfig 59 >>> 60 >>> config = PostgreSQLConfig.from_env() 61 >>> async with DataAccess(config) as data_access: 62 ... handler = DatabaseEventHandler( 63 ... data_access=data_access, 64 ... user_id="user-123", 65 ... conversation_id="conv-456", 66 ... message_id="msg-789", 67 ... ) 68 ... await handler(event) 69 70 Note: 71 テーブルスキーマ(execution_messages): 72 - user_id: VARCHAR 73 - conversation_id: VARCHAR 74 - message_id: VARCHAR 75 - chunk_type: VARCHAR 76 - content_data: JSONB 77 - created_at: TIMESTAMP (自動設定) 78 """ 79 80 def __init__( 81 self, 82 data_access: Any, # DataAccess型だが循環インポート回避のためAny 83 user_id: str, 84 conversation_id: str, 85 message_id: str, 86 table_name: str = "execution_messages", 87 chunk_type_map: Optional[Dict[EventType, str]] = None, 88 ): 89 """ 90 DatabaseEventHandlerを初期化 91 92 Args: 93 data_access: SDKのDataAccessインスタンス 94 user_id: ユーザーID 95 conversation_id: 会話ID 96 message_id: メッセージID 97 table_name: 保存先テーブル名(デフォルト: execution_messages) 98 chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション) 99 100 Example: 101 >>> handler = DatabaseEventHandler( 102 ... data_access=data_access, 103 ... user_id="user-123", 104 ... conversation_id="conv-456", 105 ... message_id="msg-789", 106 ... ) 107 """ 108 self.data_access = data_access 109 self.user_id = user_id 110 self.conversation_id = conversation_id 111 self.message_id = message_id 112 self.table_name = table_name 113 114 # デフォルトのchunk_typeマッピング 115 self._chunk_type_map = chunk_type_map or { 116 EventType.PHASE_START: "content", 117 EventType.PROGRESS_UPDATE: "content", 118 EventType.THOUGHT_MESSAGE: "reasoning_content", 119 EventType.TOOL_START: "task_created", 120 EventType.TOOL_RESULT: "task_created", 121 EventType.FILE_CREATED: "task_created", 122 EventType.USER_INTERACTION_REQUIRED: "content", 123 EventType.COMPLETION_SUCCESS: "content", 124 EventType.COMPLETION_FAILURE: "content", 125 } 126 127 logger.debug( 128 f"DatabaseEventHandler initialized: table={table_name}, " 129 f"user_id={user_id}, conversation_id={conversation_id}" 130 ) 131 132 async def __call__(self, event: StreamingEvent) -> Optional[str]: 133 """ 134 イベントをデータベースに保存 135 136 Args: 137 event: ストリーミングイベント 138 139 Returns: 140 None(DBハンドラーはストリーミングチャンクを返さない) 141 """ 142 try: 143 # chunk_typeを決定 144 chunk_type = self._chunk_type_map.get(event.event_type, "content") 145 146 # ExecutionMessageを作成 147 exec_msg = ExecutionMessage.from_streaming_event( 148 event=event, 149 user_id=self.user_id, 150 conversation_id=self.conversation_id, 151 message_id=self.message_id, 152 chunk_type=chunk_type, 153 ) 154 155 # DataAccessでDB保存 156 result = await self.data_access.insert( 157 self.table_name, 158 exec_msg.to_dict(), 159 ) 160 161 if not result.get("success"): 162 logger.error( 163 f"Failed to save event to database: {result.get('error')}" 164 ) 165 else: 166 logger.debug( 167 f"Event saved to database: {event.event_type.value}" 168 ) 169 170 return None 171 172 except Exception as e: 173 logger.error(f"Error saving event to database: {e}") 174 return None
データベースにイベントを保存するハンドラー
SDKのDataAccessを使用してexecution_messagesテーブルにイベントを保存。 マーケットプレイスUIとの連携に使用。
Features: - StreamingEventをExecutionMessageに変換 - 汎用DataAccessによるDB保存 - テーブル名のカスタマイズ対応
Example:
from agenticstar_platform.db import DataAccess, PostgreSQLConfig
config = PostgreSQLConfig.from_env() async with DataAccess(config) as data_access: ... handler = DatabaseEventHandler( ... data_access=data_access, ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789", ... ) ... await handler(event)
Note: テーブルスキーマ(execution_messages): - user_id: VARCHAR - conversation_id: VARCHAR - message_id: VARCHAR - chunk_type: VARCHAR - content_data: JSONB - created_at: TIMESTAMP (自動設定)
80 def __init__( 81 self, 82 data_access: Any, # DataAccess型だが循環インポート回避のためAny 83 user_id: str, 84 conversation_id: str, 85 message_id: str, 86 table_name: str = "execution_messages", 87 chunk_type_map: Optional[Dict[EventType, str]] = None, 88 ): 89 """ 90 DatabaseEventHandlerを初期化 91 92 Args: 93 data_access: SDKのDataAccessインスタンス 94 user_id: ユーザーID 95 conversation_id: 会話ID 96 message_id: メッセージID 97 table_name: 保存先テーブル名(デフォルト: execution_messages) 98 chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション) 99 100 Example: 101 >>> handler = DatabaseEventHandler( 102 ... data_access=data_access, 103 ... user_id="user-123", 104 ... conversation_id="conv-456", 105 ... message_id="msg-789", 106 ... ) 107 """ 108 self.data_access = data_access 109 self.user_id = user_id 110 self.conversation_id = conversation_id 111 self.message_id = message_id 112 self.table_name = table_name 113 114 # デフォルトのchunk_typeマッピング 115 self._chunk_type_map = chunk_type_map or { 116 EventType.PHASE_START: "content", 117 EventType.PROGRESS_UPDATE: "content", 118 EventType.THOUGHT_MESSAGE: "reasoning_content", 119 EventType.TOOL_START: "task_created", 120 EventType.TOOL_RESULT: "task_created", 121 EventType.FILE_CREATED: "task_created", 122 EventType.USER_INTERACTION_REQUIRED: "content", 123 EventType.COMPLETION_SUCCESS: "content", 124 EventType.COMPLETION_FAILURE: "content", 125 } 126 127 logger.debug( 128 f"DatabaseEventHandler initialized: table={table_name}, " 129 f"user_id={user_id}, conversation_id={conversation_id}" 130 )
DatabaseEventHandlerを初期化
Args: data_access: SDKのDataAccessインスタンス user_id: ユーザーID conversation_id: 会話ID message_id: メッセージID table_name: 保存先テーブル名(デフォルト: execution_messages) chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション)
Example:
handler = DatabaseEventHandler( ... data_access=data_access, ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789", ... )
177class WebhookEventHandler: 178 """ 179 Webhookにイベントを送信するハンドラー 180 181 HTTP POSTでイベントをWebhookエンドポイントに送信。 182 LibreChat API等との連携に使用。 183 184 Features: 185 - 非同期HTTP POST送信 186 - カスタムヘッダー対応 187 - タイムアウト設定 188 - メッセージフォーマットのカスタマイズ 189 190 Example: 191 >>> handler = WebhookEventHandler( 192 ... webhook_url="http://librechat-api:3081/webhook/notify", 193 ... conversation_id="conv-456", 194 ... message_id="msg-789", 195 ... ) 196 >>> await handler(event) 197 198 Note: 199 送信されるJSON形式: 200 { 201 "type": "append_message", 202 "data": { 203 "conversationId": "...", 204 "messageId": "...", 205 "chunk_type": "content", 206 "content": {"content": "..."}, 207 "finish_reason": null 208 } 209 } 210 """ 211 212 def __init__( 213 self, 214 webhook_url: str, 215 conversation_id: str, 216 message_id: str, 217 headers: Optional[Dict[str, str]] = None, 218 timeout_seconds: int = 30, 219 message_type: str = "append_message", 220 token_provider: Optional[Callable[[], Optional[str]]] = None, 221 chunk_type_map: Optional[Dict[EventType, str]] = None, 222 ): 223 """ 224 WebhookEventHandlerを初期化 225 226 Args: 227 webhook_url: WebhookエンドポイントURL 228 conversation_id: 会話ID 229 message_id: メッセージID 230 headers: カスタムHTTPヘッダー(オプション) 231 timeout_seconds: リクエストタイムアウト秒数 232 message_type: メッセージタイプ(デフォルト: append_message) 233 token_provider: 認証トークン取得関数(オプション) 234 chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション) 235 236 Example: 237 >>> handler = WebhookEventHandler( 238 ... webhook_url="http://librechat-api:3081/webhook/notify", 239 ... conversation_id="conv-456", 240 ... message_id="msg-789", 241 ... headers={"X-Custom-Header": "value"}, 242 ... ) 243 """ 244 self.webhook_url = webhook_url 245 self.conversation_id = conversation_id 246 self.message_id = message_id 247 self.headers = headers or {} 248 self.timeout_seconds = timeout_seconds 249 self.message_type = message_type 250 self.token_provider = token_provider 251 252 # デフォルトのchunk_typeマッピング 253 self._chunk_type_map = chunk_type_map or { 254 EventType.PHASE_START: "content", 255 EventType.PROGRESS_UPDATE: "content", 256 EventType.THOUGHT_MESSAGE: "reasoning_content", 257 EventType.TOOL_START: "task_created", 258 EventType.TOOL_RESULT: "task_created", 259 EventType.FILE_CREATED: "task_created", 260 EventType.USER_INTERACTION_REQUIRED: "content", 261 EventType.COMPLETION_SUCCESS: "content", 262 EventType.COMPLETION_FAILURE: "content", 263 } 264 265 # finish_reasonマッピング 266 self._finish_reason_map = { 267 EventType.COMPLETION_SUCCESS: "stop", 268 EventType.COMPLETION_FAILURE: "error", 269 } 270 271 logger.debug( 272 f"WebhookEventHandler initialized: url={webhook_url}, " 273 f"conversation_id={conversation_id}" 274 ) 275 276 async def __call__(self, event: StreamingEvent) -> Optional[str]: 277 """ 278 イベントをWebhookに送信 279 280 Args: 281 event: ストリーミングイベント 282 283 Returns: 284 None(Webhookハンドラーはストリーミングチャンクを返さない) 285 """ 286 try: 287 # aiohttpをここでimport(オプション依存のため) 288 try: 289 import aiohttp 290 except ImportError: 291 logger.error( 292 "aiohttp is required for WebhookEventHandler: " 293 "pip install aiohttp" 294 ) 295 return None 296 297 # chunk_typeを決定 298 chunk_type = self._chunk_type_map.get(event.event_type, "content") 299 300 # finish_reasonを決定 301 finish_reason = self._finish_reason_map.get(event.event_type) 302 303 # contentを構築 304 content = {"content": event.message} 305 if event.metadata: 306 content.update(event.metadata) 307 308 # Webhookペイロードを構築 309 payload = { 310 "type": self.message_type, 311 "data": { 312 "conversationId": self.conversation_id, 313 "messageId": self.message_id, 314 "chunk_type": chunk_type, 315 "content": content, 316 "finish_reason": finish_reason, 317 }, 318 } 319 320 # ヘッダーを設定 321 headers = {"Content-Type": "application/json"} 322 headers.update(self.headers) 323 324 # トークンプロバイダーがあれば認証ヘッダーを追加 325 if self.token_provider: 326 token = self.token_provider() 327 if token: 328 headers["Authorization"] = f"Bearer {token}" 329 330 # HTTP POSTを送信 331 timeout = aiohttp.ClientTimeout(total=self.timeout_seconds) 332 async with aiohttp.ClientSession(timeout=timeout) as session: 333 async with session.post( 334 self.webhook_url, 335 json=payload, 336 headers=headers, 337 ) as response: 338 # レスポンスを読み取り(リソースリークを防ぐ) 339 await response.text() 340 341 if response.status == 200: 342 logger.debug( 343 f"Webhook notification sent: {event.event_type.value}" 344 ) 345 else: 346 logger.warning( 347 f"Webhook returned status {response.status}" 348 ) 349 350 return None 351 352 except asyncio.TimeoutError: 353 logger.error("Webhook notification timeout") 354 return None 355 except Exception as e: 356 logger.error(f"Error sending webhook notification: {e}") 357 return None
Webhookにイベントを送信するハンドラー
HTTP POSTでイベントをWebhookエンドポイントに送信。 LibreChat API等との連携に使用。
Features: - 非同期HTTP POST送信 - カスタムヘッダー対応 - タイムアウト設定 - メッセージフォーマットのカスタマイズ
Example:
handler = WebhookEventHandler( ... webhook_url="http://librechat-api:3081/webhook/notify", ... conversation_id="conv-456", ... message_id="msg-789", ... ) await handler(event)
Note: 送信されるJSON形式: { "type": "append_message", "data": { "conversationId": "...", "messageId": "...", "chunk_type": "content", "content": {"content": "..."}, "finish_reason": null } }
212 def __init__( 213 self, 214 webhook_url: str, 215 conversation_id: str, 216 message_id: str, 217 headers: Optional[Dict[str, str]] = None, 218 timeout_seconds: int = 30, 219 message_type: str = "append_message", 220 token_provider: Optional[Callable[[], Optional[str]]] = None, 221 chunk_type_map: Optional[Dict[EventType, str]] = None, 222 ): 223 """ 224 WebhookEventHandlerを初期化 225 226 Args: 227 webhook_url: WebhookエンドポイントURL 228 conversation_id: 会話ID 229 message_id: メッセージID 230 headers: カスタムHTTPヘッダー(オプション) 231 timeout_seconds: リクエストタイムアウト秒数 232 message_type: メッセージタイプ(デフォルト: append_message) 233 token_provider: 認証トークン取得関数(オプション) 234 chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション) 235 236 Example: 237 >>> handler = WebhookEventHandler( 238 ... webhook_url="http://librechat-api:3081/webhook/notify", 239 ... conversation_id="conv-456", 240 ... message_id="msg-789", 241 ... headers={"X-Custom-Header": "value"}, 242 ... ) 243 """ 244 self.webhook_url = webhook_url 245 self.conversation_id = conversation_id 246 self.message_id = message_id 247 self.headers = headers or {} 248 self.timeout_seconds = timeout_seconds 249 self.message_type = message_type 250 self.token_provider = token_provider 251 252 # デフォルトのchunk_typeマッピング 253 self._chunk_type_map = chunk_type_map or { 254 EventType.PHASE_START: "content", 255 EventType.PROGRESS_UPDATE: "content", 256 EventType.THOUGHT_MESSAGE: "reasoning_content", 257 EventType.TOOL_START: "task_created", 258 EventType.TOOL_RESULT: "task_created", 259 EventType.FILE_CREATED: "task_created", 260 EventType.USER_INTERACTION_REQUIRED: "content", 261 EventType.COMPLETION_SUCCESS: "content", 262 EventType.COMPLETION_FAILURE: "content", 263 } 264 265 # finish_reasonマッピング 266 self._finish_reason_map = { 267 EventType.COMPLETION_SUCCESS: "stop", 268 EventType.COMPLETION_FAILURE: "error", 269 } 270 271 logger.debug( 272 f"WebhookEventHandler initialized: url={webhook_url}, " 273 f"conversation_id={conversation_id}" 274 )
WebhookEventHandlerを初期化
Args: webhook_url: WebhookエンドポイントURL conversation_id: 会話ID message_id: メッセージID headers: カスタムHTTPヘッダー(オプション) timeout_seconds: リクエストタイムアウト秒数 message_type: メッセージタイプ(デフォルト: append_message) token_provider: 認証トークン取得関数(オプション) chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション)
Example:
handler = WebhookEventHandler( ... webhook_url="http://librechat-api:3081/webhook/notify", ... conversation_id="conv-456", ... message_id="msg-789", ... headers={"X-Custom-Header": "value"}, ... )
360class CompositeEventHandler: 361 """ 362 複数のハンドラーを並列実行する複合ハンドラー 363 364 DatabaseEventHandlerとWebhookEventHandlerを同時実行し、 365 autonomousのPodWebhookSubscriberと同等の機能を提供。 366 367 Features: 368 - 複数ハンドラーの並列実行 369 - エラー時も他のハンドラーは継続 370 - 最初のSSEチャンクを返す(SSEハンドラーがある場合) 371 372 Example: 373 >>> db_handler = DatabaseEventHandler(...) 374 >>> webhook_handler = WebhookEventHandler(...) 375 >>> composite = CompositeEventHandler([db_handler, webhook_handler]) 376 >>> 377 >>> # EventEmitterに設定 378 >>> emitter = EventEmitter( 379 ... execution_id="exec-123", 380 ... handler=composite 381 ... ) 382 383 Note: 384 asyncio.gatherでreturn_exceptions=Trueを使用するため、 385 一部のハンドラーがエラーでも他は正常に実行される。 386 """ 387 388 def __init__( 389 self, 390 handlers: List[Callable[[StreamingEvent], Optional[str]]], 391 return_first_chunk: bool = True, 392 ): 393 """ 394 CompositeEventHandlerを初期化 395 396 Args: 397 handlers: イベントハンドラーのリスト 398 return_first_chunk: 最初のNon-Noneチャンクを返すか(デフォルト: True) 399 400 Example: 401 >>> composite = CompositeEventHandler([ 402 ... DatabaseEventHandler(...), 403 ... WebhookEventHandler(...), 404 ... create_sse_handler(), # SSEチャンクを返すハンドラー 405 ... ]) 406 """ 407 self.handlers = handlers 408 self.return_first_chunk = return_first_chunk 409 410 logger.debug( 411 f"CompositeEventHandler initialized with {len(handlers)} handlers" 412 ) 413 414 async def __call__(self, event: StreamingEvent) -> Optional[str]: 415 """ 416 全ハンドラーを並列実行 417 418 Args: 419 event: ストリーミングイベント 420 421 Returns: 422 最初のNon-Noneチャンク(return_first_chunk=Trueの場合) 423 またはNone 424 """ 425 try: 426 # 全ハンドラーを並列実行 427 tasks = [handler(event) for handler in self.handlers] 428 results = await asyncio.gather(*tasks, return_exceptions=True) 429 430 # エラーをログ出力 431 for i, result in enumerate(results): 432 if isinstance(result, Exception): 433 logger.error( 434 f"Handler {i} raised exception: {result}" 435 ) 436 437 # 最初のNon-Noneチャンクを返す 438 if self.return_first_chunk: 439 for result in results: 440 if result is not None and not isinstance(result, Exception): 441 return result 442 443 return None 444 445 except Exception as e: 446 logger.error(f"Error in CompositeEventHandler: {e}") 447 return None
複数のハンドラーを並列実行する複合ハンドラー
DatabaseEventHandlerとWebhookEventHandlerを同時実行し、 autonomousのPodWebhookSubscriberと同等の機能を提供。
Features: - 複数ハンドラーの並列実行 - エラー時も他のハンドラーは継続 - 最初のSSEチャンクを返す(SSEハンドラーがある場合)
Example:
db_handler = DatabaseEventHandler(...) webhook_handler = WebhookEventHandler(...) composite = CompositeEventHandler([db_handler, webhook_handler])
EventEmitterに設定
emitter = EventEmitter( ... execution_id="exec-123", ... handler=composite ... )
Note: asyncio.gatherでreturn_exceptions=Trueを使用するため、 一部のハンドラーがエラーでも他は正常に実行される。
388 def __init__( 389 self, 390 handlers: List[Callable[[StreamingEvent], Optional[str]]], 391 return_first_chunk: bool = True, 392 ): 393 """ 394 CompositeEventHandlerを初期化 395 396 Args: 397 handlers: イベントハンドラーのリスト 398 return_first_chunk: 最初のNon-Noneチャンクを返すか(デフォルト: True) 399 400 Example: 401 >>> composite = CompositeEventHandler([ 402 ... DatabaseEventHandler(...), 403 ... WebhookEventHandler(...), 404 ... create_sse_handler(), # SSEチャンクを返すハンドラー 405 ... ]) 406 """ 407 self.handlers = handlers 408 self.return_first_chunk = return_first_chunk 409 410 logger.debug( 411 f"CompositeEventHandler initialized with {len(handlers)} handlers" 412 )
CompositeEventHandlerを初期化
Args: handlers: イベントハンドラーのリスト return_first_chunk: 最初のNon-Noneチャンクを返すか(デフォルト: True)
Example:
composite = CompositeEventHandler([ ... DatabaseEventHandler(...), ... WebhookEventHandler(...), ... create_sse_handler(), # SSEチャンクを返すハンドラー ... ])
450def create_marketplace_handler( 451 data_access: Any, 452 webhook_url: str, 453 user_id: str, 454 conversation_id: str, 455 message_id: str, 456 token_provider: Optional[Callable[[], Optional[str]]] = None, 457) -> CompositeEventHandler: 458 """ 459 マーケットプレイス向けの標準ハンドラーを作成 460 461 DB保存 + Webhook送信の複合ハンドラーを簡単に作成するファクトリ関数。 462 463 Args: 464 data_access: SDKのDataAccessインスタンス 465 webhook_url: WebhookエンドポイントURL 466 user_id: ユーザーID 467 conversation_id: 会話ID 468 message_id: メッセージID 469 token_provider: 認証トークン取得関数(オプション) 470 471 Returns: 472 CompositeEventHandler: DB + Webhookの複合ハンドラー 473 474 Example: 475 >>> from agenticstar_platform.db import DataAccess, PostgreSQLConfig 476 >>> from agenticstar_platform.events import EventEmitter 477 >>> from agenticstar_platform.events.handlers import create_marketplace_handler 478 >>> 479 >>> config = PostgreSQLConfig.from_env() 480 >>> async with DataAccess(config) as da: 481 ... handler = create_marketplace_handler( 482 ... data_access=da, 483 ... webhook_url="http://librechat-api:3081/webhook/notify", 484 ... user_id="user-123", 485 ... conversation_id="conv-456", 486 ... message_id="msg-789", 487 ... ) 488 ... emitter = EventEmitter(execution_id="exec-001", handler=handler) 489 ... await emitter.emit(EventType.PHASE_START, "処理開始") 490 """ 491 db_handler = DatabaseEventHandler( 492 data_access=data_access, 493 user_id=user_id, 494 conversation_id=conversation_id, 495 message_id=message_id, 496 ) 497 498 webhook_handler = WebhookEventHandler( 499 webhook_url=webhook_url, 500 conversation_id=conversation_id, 501 message_id=message_id, 502 token_provider=token_provider, 503 ) 504 505 return CompositeEventHandler([db_handler, webhook_handler])
マーケットプレイス向けの標準ハンドラーを作成
DB保存 + Webhook送信の複合ハンドラーを簡単に作成するファクトリ関数。
Args: data_access: SDKのDataAccessインスタンス webhook_url: WebhookエンドポイントURL user_id: ユーザーID conversation_id: 会話ID message_id: メッセージID token_provider: 認証トークン取得関数(オプション)
Returns: CompositeEventHandler: DB + Webhookの複合ハンドラー
Example:
from agenticstar_platform.db import DataAccess, PostgreSQLConfig from agenticstar_platform.events import EventEmitter from agenticstar_platform.events.handlers import create_marketplace_handler
config = PostgreSQLConfig.from_env() async with DataAccess(config) as da: ... handler = create_marketplace_handler( ... data_access=da, ... webhook_url="http://librechat-api:3081/webhook/notify", ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789", ... ) ... emitter = EventEmitter(execution_id="exec-001", handler=handler) ... await emitter.emit(EventType.PHASE_START, "処理開始")
566class SemanticMemoryClient: 567 """ 568 汎用セマンティックメモリクライアント(Mem0ベース) 569 570 会話やコンテキストをセマンティック検索可能な形式で保存・検索します。 571 LLMとEmbedderを使用してメモリの抽出・ベクトル化を行います。 572 573 Features: 574 - マルチプロバイダーLLM/Embedder対応(Azure, OpenAI, Anthropic, Bedrock, Gemini, Groq) 575 - 基本的なメモリ追加・検索(add, search, get_all, delete, delete_all) 576 - カスタムプロンプト対応(ファクト抽出・メモリ更新プロンプトのカスタマイズ) 577 - ベクトルストア設定対応(Qdrant等) 578 579 Note: 580 - 組織階層(personal/role/org/global)は提供しない(利用側で実装) 581 - PIIマスキングは提供しない(利用側で実装) 582 583 Example: 584 >>> from agenticstar_platform import SemanticMemoryClient, SemanticMemoryConfig 585 >>> 586 >>> # 設定を作成 587 >>> config = SemanticMemoryConfig.from_toml("config.toml") 588 >>> 589 >>> # クライアント初期化 590 >>> client = SemanticMemoryClient(config) 591 >>> 592 >>> # メモリを追加 593 >>> client.add( 594 ... messages=[ 595 ... {"role": "user", "content": "私の名前は田中です"}, 596 ... {"role": "assistant", "content": "こんにちは、田中さん"} 597 ... ], 598 ... user_id="user-123" 599 ... ) 600 >>> 601 >>> # メモリを検索 602 >>> results = client.search(query="名前", user_id="user-123") 603 >>> print(results) 604 {'results': [{'memory': 'ユーザーの名前は田中', 'score': 0.95}]} 605 """ 606 607 def _get_api_key(self, config: LLMProviderConfig) -> str: 608 """APIキーを取得(SecretStr対応) 609 610 Args: 611 config: LLMProviderConfig 612 613 Returns: 614 APIキー文字列 615 """ 616 api_key = config.api_key 617 if hasattr(api_key, 'get_secret_value'): 618 return api_key.get_secret_value() 619 return api_key 620 621 def _validate_provider_config(self, config: LLMProviderConfig, context: str) -> None: 622 """プロバイダー固有の設定を検証 623 624 Args: 625 config: LLMProviderConfig 626 context: 検証コンテキスト(エラーメッセージ用) 627 628 Raises: 629 SemanticMemoryConfigError: 必須設定が不足している場合 630 """ 631 provider, _ = self._parse_model_string(config.model) 632 633 if provider == "azure": 634 if not config.base_url: 635 raise SemanticMemoryConfigError( 636 f"{context}: Azure provider requires 'base_url' (azure_endpoint)" 637 ) 638 if not config.api_version: 639 logger.warning( 640 f"{context}: Azure provider 'api_version' not set, using default" 641 ) 642 643 elif provider == "bedrock": 644 missing = [] 645 if not config.aws_access_key_id: 646 missing.append("aws_access_key_id") 647 if not config.aws_secret_access_key: 648 missing.append("aws_secret_access_key") 649 if not config.aws_region_name: 650 missing.append("aws_region_name") 651 if missing: 652 raise SemanticMemoryConfigError( 653 f"{context}: Bedrock provider requires: {', '.join(missing)}" 654 ) 655 656 def __init__(self, config: SemanticMemoryConfig): 657 """ 658 Initialize Semantic Memory Client 659 660 Args: 661 config: SemanticMemoryConfig instance 662 663 Raises: 664 SemanticMemoryConfigError: 設定が不正な場合 665 """ 666 if not config: 667 raise SemanticMemoryConfigError("SemanticMemoryConfig is required") 668 669 if not config.llm_config or not config.embedder_config: 670 raise SemanticMemoryConfigError( 671 "Both llm_config and embedder_config are required" 672 ) 673 674 # Validate provider-specific configurations 675 self._validate_provider_config(config.llm_config, "llm_config") 676 self._validate_provider_config(config.embedder_config, "embedder_config") 677 678 self._config = config 679 self._memory = None 680 self._enabled = False 681 682 try: 683 # Configure Mem0 telemetry before importing 684 _configure_mem0_telemetry() 685 686 from mem0 import Memory 687 688 # Build Mem0 format config 689 mem0_config = { 690 'llm': self._convert_llm_to_mem0(config.llm_config), 691 'embedder': self._convert_embedder_to_mem0(config.embedder_config), 692 } 693 694 if config.vector_store: 695 mem0_config['vector_store'] = config.vector_store 696 697 if config.rerank: 698 mem0_config['rerank'] = config.rerank 699 700 if config.custom_fact_extraction_prompt: 701 mem0_config['custom_fact_extraction_prompt'] = config.custom_fact_extraction_prompt 702 703 if config.custom_update_memory_prompt: 704 mem0_config['custom_update_memory_prompt'] = config.custom_update_memory_prompt 705 706 self._memory = Memory.from_config(mem0_config) 707 self._enabled = True 708 709 logger.info( 710 f"SemanticMemoryClient initialized - " 711 f"llm={config.llm_config.model}, " 712 f"embedder={config.embedder_config.model}" 713 ) 714 715 except ImportError as e: 716 logger.error(f"mem0 library not installed: {e}") 717 raise SemanticMemoryConfigError( 718 "mem0 library is required: pip install mem0ai" 719 ) from e 720 except Exception as e: 721 logger.error(f"Failed to initialize SemanticMemoryClient: {e}") 722 raise SemanticMemoryConfigError( 723 f"Failed to initialize Mem0: {e}" 724 ) from e 725 726 @property 727 def enabled(self) -> bool: 728 """クライアントが有効かどうか""" 729 return self._enabled 730 731 def _normalize_provider(self, provider: str) -> str: 732 """プロバイダー名を正規化(エイリアス対応)""" 733 provider_aliases = { 734 # Azure 735 "azure": "azure", 736 "azure_openai": "azure", 737 "azure-openai": "azure", 738 # Bedrock 739 "bedrock": "bedrock", 740 "aws_bedrock": "bedrock", 741 "aws-bedrock": "bedrock", 742 # Google 743 "gemini": "gemini", 744 "google": "gemini", 745 "google-genai": "gemini", 746 "google_genai": "gemini", 747 } 748 return provider_aliases.get(provider.lower(), provider.lower()) 749 750 def _parse_model_string(self, model: str) -> tuple[str, str]: 751 """モデル文字列をプロバイダーとモデル名に分離 752 753 Args: 754 model: "provider/model-name" or "model-name" format 755 756 Returns: 757 (provider, model_name) tuple 758 """ 759 if "/" in model: 760 provider = model.split("/")[0] 761 model_name = model.split("/", 1)[1] 762 else: 763 provider = "openai" 764 model_name = model 765 766 return self._normalize_provider(provider), model_name 767 768 def _convert_llm_to_mem0(self, llm_config: LLMProviderConfig) -> Dict[str, Any]: 769 """LLMProviderConfigからMem0形式のLLM設定に変換""" 770 provider, model_name = self._parse_model_string(llm_config.model) 771 api_key = self._get_api_key(llm_config) 772 773 if provider == "azure": 774 return { 775 'provider': 'azure_openai', 776 'config': { 777 'azure_kwargs': { 778 'api_key': api_key, 779 'azure_endpoint': llm_config.base_url, 780 'azure_deployment': model_name, 781 'api_version': llm_config.api_version or "2024-02-15-preview" 782 } 783 } 784 } 785 elif provider == "openai": 786 return { 787 'provider': 'openai', 788 'config': { 789 'model': model_name, 790 'api_key': api_key, 791 } 792 } 793 elif provider == "anthropic": 794 return { 795 'provider': 'anthropic', 796 'config': { 797 'model': model_name, 798 'api_key': api_key, 799 } 800 } 801 elif provider == "bedrock": 802 return { 803 'provider': 'aws_bedrock', 804 'config': { 805 'model': model_name, 806 'aws_access_key_id': llm_config.aws_access_key_id, 807 'aws_secret_access_key': llm_config.aws_secret_access_key, 808 'aws_region_name': llm_config.aws_region_name, 809 } 810 } 811 elif provider == "gemini": 812 return { 813 'provider': 'gemini', 814 'config': { 815 'model': model_name, 816 'api_key': api_key, 817 } 818 } 819 elif provider == "groq": 820 return { 821 'provider': 'groq', 822 'config': { 823 'model': model_name, 824 'api_key': api_key, 825 } 826 } 827 else: 828 # Default to provider name with OpenAI-like config 829 return { 830 'provider': provider, 831 'config': { 832 'model': model_name, 833 'api_key': api_key, 834 } 835 } 836 837 def _convert_embedder_to_mem0(self, embedder_config: LLMProviderConfig) -> Dict[str, Any]: 838 """LLMProviderConfigからMem0形式のEmbedder設定に変換""" 839 provider, model_name = self._parse_model_string(embedder_config.model) 840 api_key = self._get_api_key(embedder_config) 841 842 if provider == "azure": 843 return { 844 'provider': 'azure_openai', 845 'config': { 846 'model': model_name, 847 'azure_kwargs': { 848 'api_key': api_key, 849 'azure_endpoint': embedder_config.base_url, 850 'azure_deployment': model_name, 851 'api_version': embedder_config.api_version or "2024-02-15-preview" 852 } 853 } 854 } 855 elif provider == "openai": 856 return { 857 'provider': 'openai', 858 'config': { 859 'model': model_name, 860 'api_key': api_key, 861 } 862 } 863 elif provider == "bedrock": 864 return { 865 'provider': 'aws_bedrock', 866 'config': { 867 'model': model_name, 868 'aws_access_key_id': embedder_config.aws_access_key_id, 869 'aws_secret_access_key': embedder_config.aws_secret_access_key, 870 'aws_region_name': embedder_config.aws_region_name, 871 } 872 } 873 elif provider == "gemini": 874 return { 875 'provider': 'gemini', 876 'config': { 877 'model': model_name, 878 'api_key': api_key, 879 } 880 } 881 else: 882 # Default to provider name with OpenAI-like config 883 return { 884 'provider': provider, 885 'config': { 886 'model': model_name, 887 'api_key': api_key, 888 } 889 } 890 891 def add( 892 self, 893 messages: List[Dict[str, str]], 894 user_id: str, 895 metadata: Optional[Dict[str, Any]] = None, 896 ) -> Dict[str, Any]: 897 """ 898 メモリを追加 899 900 会話メッセージをセマンティックメモリに保存します。 901 保存されたメモリは後でsearch()メソッドで検索できます。 902 903 Args: 904 messages: メッセージリスト。各要素は {"role": "user"|"assistant", "content": "..."} 形式 905 user_id: ユーザーID(必須)。ユーザーごとにメモリが分離されます 906 metadata: 追加メタデータ(オプション)。検索時のフィルタリングに使用可能 907 908 Returns: 909 Dict[str, Any]: 追加結果。成功時は {"results": [...]} 形式 910 911 Raises: 912 SemanticMemoryError: メモリ追加に失敗した場合 913 914 Example: 915 >>> # 会話を保存 916 >>> result = memory_client.add( 917 ... messages=[ 918 ... {"role": "user", "content": "東京の天気は?"}, 919 ... {"role": "assistant", "content": "東京は晴れです。"} 920 ... ], 921 ... user_id="user-123", 922 ... metadata={"topic": "weather"} 923 ... ) 924 >>> print(result) 925 {'results': [...]} 926 """ 927 if not self._enabled: 928 return {"error": "SemanticMemoryClient is not enabled"} 929 930 try: 931 result = self._memory.add( 932 messages=messages, 933 user_id=user_id, 934 metadata=metadata or {}, 935 ) 936 logger.debug(f"Memory added for user_id={user_id}") 937 return result 938 939 except Exception as e: 940 logger.error(f"Failed to add memory: {e}") 941 raise SemanticMemoryError(f"Failed to add memory: {e}") from e 942 943 def search( 944 self, 945 query: str, 946 user_id: str, 947 limit: int = 10, 948 ) -> Dict[str, Any]: 949 """ 950 メモリを検索 951 952 保存されたメモリからセマンティック検索を行います。 953 クエリに意味的に関連するメモリを関連度順に返します。 954 955 Args: 956 query: 検索クエリ。自然言語で記述。例: "天気に関する会話" 957 user_id: ユーザーID。add()で指定したuser_idと同じ値を使用 958 limit: 結果数上限(デフォルト: 10) 959 960 Returns: 961 Dict[str, Any]: 検索結果。{"results": [{"id": "...", "memory": "...", "score": 0.9}, ...]} 962 963 Raises: 964 SemanticMemoryError: 検索に失敗した場合 965 966 Example: 967 >>> # メモリを検索 968 >>> results = memory_client.search( 969 ... query="天気について教えて", 970 ... user_id="user-123", 971 ... limit=5 972 ... ) 973 >>> for item in results.get("results", []): 974 ... print(f"Score: {item['score']:.2f} - {item['memory']}") 975 """ 976 if not self._enabled: 977 return {"error": "SemanticMemoryClient is not enabled", "results": []} 978 979 try: 980 result = self._memory.search( 981 query=query, 982 user_id=user_id, 983 limit=limit, 984 ) 985 logger.debug( 986 f"Memory searched for user_id={user_id}, " 987 f"results={len(result.get('results', []))}" 988 ) 989 return result 990 991 except Exception as e: 992 logger.error(f"Failed to search memory: {e}") 993 raise SemanticMemoryError(f"Failed to search memory: {e}") from e 994 995 def get_all(self, user_id: str) -> Dict[str, Any]: 996 """ 997 ユーザーの全メモリを取得 998 999 指定されたユーザーに紐づく全てのメモリを取得します。 1000 検索ではなく、保存された全てのメモリを一覧表示する場合に使用します。 1001 1002 Args: 1003 user_id: ユーザーID 1004 1005 Returns: 1006 Dict[str, Any]: メモリリスト。{"results": [{"id": "...", "memory": "...", "created_at": "..."}, ...]} 1007 1008 Raises: 1009 SemanticMemoryError: 取得に失敗した場合 1010 1011 Example: 1012 >>> # ユーザーの全メモリを取得 1013 >>> all_memories = memory_client.get_all(user_id="user-123") 1014 >>> print(f"Total memories: {len(all_memories.get('results', []))}") 1015 >>> for mem in all_memories.get("results", []): 1016 ... print(f"[{mem['id']}] {mem['memory']}") 1017 """ 1018 if not self._enabled: 1019 return {"error": "SemanticMemoryClient is not enabled", "results": []} 1020 1021 try: 1022 result = self._memory.get_all(user_id=user_id) 1023 logger.debug( 1024 f"All memories retrieved for user_id={user_id}, " 1025 f"count={len(result.get('results', []))}" 1026 ) 1027 return result 1028 1029 except Exception as e: 1030 logger.error(f"Failed to get all memories: {e}") 1031 raise SemanticMemoryError(f"Failed to get all memories: {e}") from e 1032 1033 def delete(self, memory_id: str) -> Dict[str, Any]: 1034 """ 1035 メモリを削除 1036 1037 指定されたIDのメモリを削除します。 1038 メモリIDはget_all()やsearch()で取得できます。 1039 1040 Args: 1041 memory_id: 削除するメモリのID 1042 1043 Returns: 1044 Dict[str, Any]: 削除結果。{"message": "Memory deleted successfully"} 1045 1046 Raises: 1047 SemanticMemoryError: 削除に失敗した場合 1048 1049 Example: 1050 >>> # 特定のメモリを削除 1051 >>> result = memory_client.delete(memory_id="mem-abc123") 1052 >>> print(result) 1053 {'message': 'Memory deleted successfully'} 1054 """ 1055 if not self._enabled: 1056 return {"error": "SemanticMemoryClient is not enabled"} 1057 1058 try: 1059 result = self._memory.delete(memory_id=memory_id) 1060 logger.debug(f"Memory deleted: {memory_id}") 1061 return result 1062 1063 except Exception as e: 1064 logger.error(f"Failed to delete memory: {e}") 1065 raise SemanticMemoryError(f"Failed to delete memory: {e}") from e 1066 1067 def delete_all(self, user_id: str) -> Dict[str, Any]: 1068 """ 1069 ユーザーの全メモリを削除 1070 1071 指定されたユーザーに紐づく全てのメモリを削除します。 1072 この操作は取り消せません。注意して使用してください。 1073 1074 Args: 1075 user_id: ユーザーID 1076 1077 Returns: 1078 Dict[str, Any]: 削除結果。{"message": "All memories deleted successfully"} 1079 1080 Raises: 1081 SemanticMemoryError: 削除に失敗した場合 1082 1083 Example: 1084 >>> # ユーザーの全メモリを削除 1085 >>> result = memory_client.delete_all(user_id="user-123") 1086 >>> print(result) 1087 {'message': 'All memories deleted successfully'} 1088 1089 Warning: 1090 この操作は取り消せません。全てのメモリが永久に削除されます。 1091 """ 1092 if not self._enabled: 1093 return {"error": "SemanticMemoryClient is not enabled"} 1094 1095 try: 1096 result = self._memory.delete_all(user_id=user_id) 1097 logger.debug(f"All memories deleted for user_id={user_id}") 1098 return result 1099 1100 except Exception as e: 1101 logger.error(f"Failed to delete all memories: {e}") 1102 raise SemanticMemoryError(f"Failed to delete all memories: {e}") from e 1103 1104 async def cleanup(self) -> None: 1105 """ 1106 クライアントリソースのクリーンアップ 1107 1108 非同期コンテキストマネージャーの終了時や、 1109 明示的にリソースを解放する場合に呼び出します。 1110 1111 Example: 1112 >>> # 明示的なクリーンアップ 1113 >>> await memory_client.cleanup() 1114 1115 >>> # コンテキストマネージャー使用時は自動で呼ばれる 1116 >>> async with memory_client: 1117 ... await memory_client.add(...) 1118 ... # ここで自動的にcleanup()が呼ばれる 1119 """ 1120 logger.debug("SemanticMemoryClient cleanup called") 1121 # No specific cleanup needed for mem0 at this time
汎用セマンティックメモリクライアント(Mem0ベース)
会話やコンテキストをセマンティック検索可能な形式で保存・検索します。 LLMとEmbedderを使用してメモリの抽出・ベクトル化を行います。
Features: - マルチプロバイダーLLM/Embedder対応(Azure, OpenAI, Anthropic, Bedrock, Gemini, Groq) - 基本的なメモリ追加・検索(add, search, get_all, delete, delete_all) - カスタムプロンプト対応(ファクト抽出・メモリ更新プロンプトのカスタマイズ) - ベクトルストア設定対応(Qdrant等)
Note: - 組織階層(personal/role/org/global)は提供しない(利用側で実装) - PIIマスキングは提供しない(利用側で実装)
Example:
from agenticstar_platform import SemanticMemoryClient, SemanticMemoryConfig
設定を作成
config = SemanticMemoryConfig.from_toml("config.toml")
クライアント初期化
client = SemanticMemoryClient(config)
メモリを追加
client.add( ... messages=[ ... {"role": "user", "content": "私の名前は田中です"}, ... {"role": "assistant", "content": "こんにちは、田中さん"} ... ], ... user_id="user-123" ... )
メモリを検索
results = client.search(query="名前", user_id="user-123") print(results) {'results': [{'memory': 'ユーザーの名前は田中', 'score': 0.95}]}
656 def __init__(self, config: SemanticMemoryConfig): 657 """ 658 Initialize Semantic Memory Client 659 660 Args: 661 config: SemanticMemoryConfig instance 662 663 Raises: 664 SemanticMemoryConfigError: 設定が不正な場合 665 """ 666 if not config: 667 raise SemanticMemoryConfigError("SemanticMemoryConfig is required") 668 669 if not config.llm_config or not config.embedder_config: 670 raise SemanticMemoryConfigError( 671 "Both llm_config and embedder_config are required" 672 ) 673 674 # Validate provider-specific configurations 675 self._validate_provider_config(config.llm_config, "llm_config") 676 self._validate_provider_config(config.embedder_config, "embedder_config") 677 678 self._config = config 679 self._memory = None 680 self._enabled = False 681 682 try: 683 # Configure Mem0 telemetry before importing 684 _configure_mem0_telemetry() 685 686 from mem0 import Memory 687 688 # Build Mem0 format config 689 mem0_config = { 690 'llm': self._convert_llm_to_mem0(config.llm_config), 691 'embedder': self._convert_embedder_to_mem0(config.embedder_config), 692 } 693 694 if config.vector_store: 695 mem0_config['vector_store'] = config.vector_store 696 697 if config.rerank: 698 mem0_config['rerank'] = config.rerank 699 700 if config.custom_fact_extraction_prompt: 701 mem0_config['custom_fact_extraction_prompt'] = config.custom_fact_extraction_prompt 702 703 if config.custom_update_memory_prompt: 704 mem0_config['custom_update_memory_prompt'] = config.custom_update_memory_prompt 705 706 self._memory = Memory.from_config(mem0_config) 707 self._enabled = True 708 709 logger.info( 710 f"SemanticMemoryClient initialized - " 711 f"llm={config.llm_config.model}, " 712 f"embedder={config.embedder_config.model}" 713 ) 714 715 except ImportError as e: 716 logger.error(f"mem0 library not installed: {e}") 717 raise SemanticMemoryConfigError( 718 "mem0 library is required: pip install mem0ai" 719 ) from e 720 except Exception as e: 721 logger.error(f"Failed to initialize SemanticMemoryClient: {e}") 722 raise SemanticMemoryConfigError( 723 f"Failed to initialize Mem0: {e}" 724 ) from e
Initialize Semantic Memory Client
Args: config: SemanticMemoryConfig instance
Raises: SemanticMemoryConfigError: 設定が不正な場合
891 def add( 892 self, 893 messages: List[Dict[str, str]], 894 user_id: str, 895 metadata: Optional[Dict[str, Any]] = None, 896 ) -> Dict[str, Any]: 897 """ 898 メモリを追加 899 900 会話メッセージをセマンティックメモリに保存します。 901 保存されたメモリは後でsearch()メソッドで検索できます。 902 903 Args: 904 messages: メッセージリスト。各要素は {"role": "user"|"assistant", "content": "..."} 形式 905 user_id: ユーザーID(必須)。ユーザーごとにメモリが分離されます 906 metadata: 追加メタデータ(オプション)。検索時のフィルタリングに使用可能 907 908 Returns: 909 Dict[str, Any]: 追加結果。成功時は {"results": [...]} 形式 910 911 Raises: 912 SemanticMemoryError: メモリ追加に失敗した場合 913 914 Example: 915 >>> # 会話を保存 916 >>> result = memory_client.add( 917 ... messages=[ 918 ... {"role": "user", "content": "東京の天気は?"}, 919 ... {"role": "assistant", "content": "東京は晴れです。"} 920 ... ], 921 ... user_id="user-123", 922 ... metadata={"topic": "weather"} 923 ... ) 924 >>> print(result) 925 {'results': [...]} 926 """ 927 if not self._enabled: 928 return {"error": "SemanticMemoryClient is not enabled"} 929 930 try: 931 result = self._memory.add( 932 messages=messages, 933 user_id=user_id, 934 metadata=metadata or {}, 935 ) 936 logger.debug(f"Memory added for user_id={user_id}") 937 return result 938 939 except Exception as e: 940 logger.error(f"Failed to add memory: {e}") 941 raise SemanticMemoryError(f"Failed to add memory: {e}") from e
メモリを追加
会話メッセージをセマンティックメモリに保存します。 保存されたメモリは後でsearch()メソッドで検索できます。
Args: messages: メッセージリスト。各要素は {"role": "user"|"assistant", "content": "..."} 形式 user_id: ユーザーID(必須)。ユーザーごとにメモリが分離されます metadata: 追加メタデータ(オプション)。検索時のフィルタリングに使用可能
Returns: Dict[str, Any]: 追加結果。成功時は {"results": [...]} 形式
Raises: SemanticMemoryError: メモリ追加に失敗した場合
Example:
会話を保存
result = memory_client.add( ... messages=[ ... {"role": "user", "content": "東京の天気は?"}, ... {"role": "assistant", "content": "東京は晴れです。"} ... ], ... user_id="user-123", ... metadata={"topic": "weather"} ... ) print(result) {'results': [...]}
943 def search( 944 self, 945 query: str, 946 user_id: str, 947 limit: int = 10, 948 ) -> Dict[str, Any]: 949 """ 950 メモリを検索 951 952 保存されたメモリからセマンティック検索を行います。 953 クエリに意味的に関連するメモリを関連度順に返します。 954 955 Args: 956 query: 検索クエリ。自然言語で記述。例: "天気に関する会話" 957 user_id: ユーザーID。add()で指定したuser_idと同じ値を使用 958 limit: 結果数上限(デフォルト: 10) 959 960 Returns: 961 Dict[str, Any]: 検索結果。{"results": [{"id": "...", "memory": "...", "score": 0.9}, ...]} 962 963 Raises: 964 SemanticMemoryError: 検索に失敗した場合 965 966 Example: 967 >>> # メモリを検索 968 >>> results = memory_client.search( 969 ... query="天気について教えて", 970 ... user_id="user-123", 971 ... limit=5 972 ... ) 973 >>> for item in results.get("results", []): 974 ... print(f"Score: {item['score']:.2f} - {item['memory']}") 975 """ 976 if not self._enabled: 977 return {"error": "SemanticMemoryClient is not enabled", "results": []} 978 979 try: 980 result = self._memory.search( 981 query=query, 982 user_id=user_id, 983 limit=limit, 984 ) 985 logger.debug( 986 f"Memory searched for user_id={user_id}, " 987 f"results={len(result.get('results', []))}" 988 ) 989 return result 990 991 except Exception as e: 992 logger.error(f"Failed to search memory: {e}") 993 raise SemanticMemoryError(f"Failed to search memory: {e}") from e
メモリを検索
保存されたメモリからセマンティック検索を行います。 クエリに意味的に関連するメモリを関連度順に返します。
Args: query: 検索クエリ。自然言語で記述。例: "天気に関する会話" user_id: ユーザーID。add()で指定したuser_idと同じ値を使用 limit: 結果数上限(デフォルト: 10)
Returns: Dict[str, Any]: 検索結果。{"results": [{"id": "...", "memory": "...", "score": 0.9}, ...]}
Raises: SemanticMemoryError: 検索に失敗した場合
Example:
メモリを検索
results = memory_client.search( ... query="天気について教えて", ... user_id="user-123", ... limit=5 ... ) for item in results.get("results", []): ... print(f"Score: {item['score']:.2f} - {item['memory']}")
995 def get_all(self, user_id: str) -> Dict[str, Any]: 996 """ 997 ユーザーの全メモリを取得 998 999 指定されたユーザーに紐づく全てのメモリを取得します。 1000 検索ではなく、保存された全てのメモリを一覧表示する場合に使用します。 1001 1002 Args: 1003 user_id: ユーザーID 1004 1005 Returns: 1006 Dict[str, Any]: メモリリスト。{"results": [{"id": "...", "memory": "...", "created_at": "..."}, ...]} 1007 1008 Raises: 1009 SemanticMemoryError: 取得に失敗した場合 1010 1011 Example: 1012 >>> # ユーザーの全メモリを取得 1013 >>> all_memories = memory_client.get_all(user_id="user-123") 1014 >>> print(f"Total memories: {len(all_memories.get('results', []))}") 1015 >>> for mem in all_memories.get("results", []): 1016 ... print(f"[{mem['id']}] {mem['memory']}") 1017 """ 1018 if not self._enabled: 1019 return {"error": "SemanticMemoryClient is not enabled", "results": []} 1020 1021 try: 1022 result = self._memory.get_all(user_id=user_id) 1023 logger.debug( 1024 f"All memories retrieved for user_id={user_id}, " 1025 f"count={len(result.get('results', []))}" 1026 ) 1027 return result 1028 1029 except Exception as e: 1030 logger.error(f"Failed to get all memories: {e}") 1031 raise SemanticMemoryError(f"Failed to get all memories: {e}") from e
ユーザーの全メモリを取得
指定されたユーザーに紐づく全てのメモリを取得します。 検索ではなく、保存された全てのメモリを一覧表示する場合に使用します。
Args: user_id: ユーザーID
Returns: Dict[str, Any]: メモリリスト。{"results": [{"id": "...", "memory": "...", "created_at": "..."}, ...]}
Raises: SemanticMemoryError: 取得に失敗した場合
Example:
ユーザーの全メモリを取得
all_memories = memory_client.get_all(user_id="user-123") print(f"Total memories: {len(all_memories.get('results', []))}") for mem in all_memories.get("results", []): ... print(f"[{mem['id']}] {mem['memory']}")
1033 def delete(self, memory_id: str) -> Dict[str, Any]: 1034 """ 1035 メモリを削除 1036 1037 指定されたIDのメモリを削除します。 1038 メモリIDはget_all()やsearch()で取得できます。 1039 1040 Args: 1041 memory_id: 削除するメモリのID 1042 1043 Returns: 1044 Dict[str, Any]: 削除結果。{"message": "Memory deleted successfully"} 1045 1046 Raises: 1047 SemanticMemoryError: 削除に失敗した場合 1048 1049 Example: 1050 >>> # 特定のメモリを削除 1051 >>> result = memory_client.delete(memory_id="mem-abc123") 1052 >>> print(result) 1053 {'message': 'Memory deleted successfully'} 1054 """ 1055 if not self._enabled: 1056 return {"error": "SemanticMemoryClient is not enabled"} 1057 1058 try: 1059 result = self._memory.delete(memory_id=memory_id) 1060 logger.debug(f"Memory deleted: {memory_id}") 1061 return result 1062 1063 except Exception as e: 1064 logger.error(f"Failed to delete memory: {e}") 1065 raise SemanticMemoryError(f"Failed to delete memory: {e}") from e
メモリを削除
指定されたIDのメモリを削除します。 メモリIDはget_all()やsearch()で取得できます。
Args: memory_id: 削除するメモリのID
Returns: Dict[str, Any]: 削除結果。{"message": "Memory deleted successfully"}
Raises: SemanticMemoryError: 削除に失敗した場合
Example:
特定のメモリを削除
result = memory_client.delete(memory_id="mem-abc123") print(result) {'message': 'Memory deleted successfully'}
1067 def delete_all(self, user_id: str) -> Dict[str, Any]: 1068 """ 1069 ユーザーの全メモリを削除 1070 1071 指定されたユーザーに紐づく全てのメモリを削除します。 1072 この操作は取り消せません。注意して使用してください。 1073 1074 Args: 1075 user_id: ユーザーID 1076 1077 Returns: 1078 Dict[str, Any]: 削除結果。{"message": "All memories deleted successfully"} 1079 1080 Raises: 1081 SemanticMemoryError: 削除に失敗した場合 1082 1083 Example: 1084 >>> # ユーザーの全メモリを削除 1085 >>> result = memory_client.delete_all(user_id="user-123") 1086 >>> print(result) 1087 {'message': 'All memories deleted successfully'} 1088 1089 Warning: 1090 この操作は取り消せません。全てのメモリが永久に削除されます。 1091 """ 1092 if not self._enabled: 1093 return {"error": "SemanticMemoryClient is not enabled"} 1094 1095 try: 1096 result = self._memory.delete_all(user_id=user_id) 1097 logger.debug(f"All memories deleted for user_id={user_id}") 1098 return result 1099 1100 except Exception as e: 1101 logger.error(f"Failed to delete all memories: {e}") 1102 raise SemanticMemoryError(f"Failed to delete all memories: {e}") from e
ユーザーの全メモリを削除
指定されたユーザーに紐づく全てのメモリを削除します。 この操作は取り消せません。注意して使用してください。
Args: user_id: ユーザーID
Returns: Dict[str, Any]: 削除結果。{"message": "All memories deleted successfully"}
Raises: SemanticMemoryError: 削除に失敗した場合
Example:
ユーザーの全メモリを削除
result = memory_client.delete_all(user_id="user-123") print(result) {'message': 'All memories deleted successfully'}
Warning: この操作は取り消せません。全てのメモリが永久に削除されます。
1104 async def cleanup(self) -> None: 1105 """ 1106 クライアントリソースのクリーンアップ 1107 1108 非同期コンテキストマネージャーの終了時や、 1109 明示的にリソースを解放する場合に呼び出します。 1110 1111 Example: 1112 >>> # 明示的なクリーンアップ 1113 >>> await memory_client.cleanup() 1114 1115 >>> # コンテキストマネージャー使用時は自動で呼ばれる 1116 >>> async with memory_client: 1117 ... await memory_client.add(...) 1118 ... # ここで自動的にcleanup()が呼ばれる 1119 """ 1120 logger.debug("SemanticMemoryClient cleanup called") 1121 # No specific cleanup needed for mem0 at this time
クライアントリソースのクリーンアップ
非同期コンテキストマネージャーの終了時や、 明示的にリソースを解放する場合に呼び出します。
Example:
明示的なクリーンアップ
await memory_client.cleanup()
>>> # コンテキストマネージャー使用時は自動で呼ばれる >>> async with memory_client: ... await memory_client.add(...) ... # ここで自動的にcleanup()が呼ばれる
462@dataclass 463class SemanticMemoryConfig: 464 """Semantic Memory設定 465 466 Example: 467 >>> # 辞書から作成 468 >>> config = SemanticMemoryConfig.from_dict({ 469 ... "llm": { 470 ... "model": "azure/gpt-4", 471 ... "api_key": "key", 472 ... "base_url": "https://...", 473 ... "api_version": "2024-02-15-preview" 474 ... }, 475 ... "embedder": { 476 ... "model": "azure/text-embedding-ada-002", 477 ... "api_key": "key", 478 ... "base_url": "https://...", 479 ... "api_version": "2024-02-15-preview" 480 ... } 481 ... }) 482 483 >>> # TOMLファイルから作成 484 >>> config = SemanticMemoryConfig.from_toml("config.toml", section="memory") 485 486 >>> # Qdrant Vector Store付きで作成 487 >>> config = SemanticMemoryConfig( 488 ... llm_config=llm_config, 489 ... embedder_config=embedder_config, 490 ... vector_store=QdrantVectorStoreConfig( 491 ... url="http://localhost:6333", 492 ... collection_name="my_memories", 493 ... ).to_dict(), 494 ... ) 495 """ 496 llm_config: LLMProviderConfig 497 embedder_config: LLMProviderConfig 498 vector_store: Optional[Dict[str, Any]] = None 499 rerank: Optional[Dict[str, Any]] = None 500 custom_fact_extraction_prompt: Optional[str] = None 501 custom_update_memory_prompt: Optional[str] = None 502 503 @classmethod 504 def from_dict(cls, data: Dict[str, Any]) -> "SemanticMemoryConfig": 505 """辞書からSemanticMemoryConfigを作成 506 507 Args: 508 data: 設定辞書。"llm"と"embedder"キーを含む必要がある 509 510 Returns: 511 SemanticMemoryConfig instance 512 513 Example: 514 >>> config = SemanticMemoryConfig.from_dict({ 515 ... "llm": {"model": "azure/gpt-4", "api_key": "key", "base_url": "..."}, 516 ... "embedder": {"model": "azure/text-embedding-ada-002", "api_key": "key", "base_url": "..."} 517 ... }) 518 """ 519 return cls( 520 llm_config=LLMProviderConfig.from_dict(data["llm"]), 521 embedder_config=LLMProviderConfig.from_dict(data["embedder"]), 522 vector_store=data.get("vector_store"), 523 rerank=data.get("rerank"), 524 custom_fact_extraction_prompt=data.get("custom_fact_extraction_prompt"), 525 custom_update_memory_prompt=data.get("custom_update_memory_prompt"), 526 ) 527 528 @classmethod 529 def from_toml(cls, toml_path: str, section: str = "memory") -> "SemanticMemoryConfig": 530 """TOMLファイルからSemanticMemoryConfigを作成 531 532 Args: 533 toml_path: TOMLファイルパス 534 section: セクション名(ドット区切りでネスト対応) 535 536 Returns: 537 SemanticMemoryConfig instance 538 539 Example: 540 >>> # config.toml の [memory] セクションを読み込み 541 >>> config = SemanticMemoryConfig.from_toml("config.toml") 542 543 >>> # 以下のTOML構造を期待: 544 >>> # [memory.llm] 545 >>> # model = "azure/gpt-4" 546 >>> # api_key = "..." 547 >>> # 548 >>> # [memory.embedder] 549 >>> # model = "azure/text-embedding-ada-002" 550 >>> # api_key = "..." 551 """ 552 if not TOML_AVAILABLE: 553 raise ImportError("toml library is required: pip install toml") 554 555 with open(toml_path, "r") as f: 556 toml_data = toml.load(f) 557 558 # ネストされたセクションをドット区切りで辿る 559 data = toml_data 560 for key in section.split("."): 561 data = data[key] 562 563 return cls.from_dict(data)
Semantic Memory設定
Example:
辞書から作成
config = SemanticMemoryConfig.from_dict({ ... "llm": { ... "model": "azure/gpt-4", ... "api_key": "key", ... "base_url": "https://...", ... "api_version": "2024-02-15-preview" ... }, ... "embedder": { ... "model": "azure/text-embedding-ada-002", ... "api_key": "key", ... "base_url": "https://...", ... "api_version": "2024-02-15-preview" ... } ... })
>>> # TOMLファイルから作成 >>> config = SemanticMemoryConfig.from_toml("config.toml", section="memory") >>> # Qdrant Vector Store付きで作成 >>> config = SemanticMemoryConfig( ... llm_config=llm_config, ... embedder_config=embedder_config, ... vector_store=QdrantVectorStoreConfig( ... url="http://localhost:6333", ... collection_name="my_memories", ... ).to_dict(), ... )
503 @classmethod 504 def from_dict(cls, data: Dict[str, Any]) -> "SemanticMemoryConfig": 505 """辞書からSemanticMemoryConfigを作成 506 507 Args: 508 data: 設定辞書。"llm"と"embedder"キーを含む必要がある 509 510 Returns: 511 SemanticMemoryConfig instance 512 513 Example: 514 >>> config = SemanticMemoryConfig.from_dict({ 515 ... "llm": {"model": "azure/gpt-4", "api_key": "key", "base_url": "..."}, 516 ... "embedder": {"model": "azure/text-embedding-ada-002", "api_key": "key", "base_url": "..."} 517 ... }) 518 """ 519 return cls( 520 llm_config=LLMProviderConfig.from_dict(data["llm"]), 521 embedder_config=LLMProviderConfig.from_dict(data["embedder"]), 522 vector_store=data.get("vector_store"), 523 rerank=data.get("rerank"), 524 custom_fact_extraction_prompt=data.get("custom_fact_extraction_prompt"), 525 custom_update_memory_prompt=data.get("custom_update_memory_prompt"), 526 )
辞書からSemanticMemoryConfigを作成
Args: data: 設定辞書。"llm"と"embedder"キーを含む必要がある
Returns: SemanticMemoryConfig instance
Example:
config = SemanticMemoryConfig.from_dict({ ... "llm": {"model": "azure/gpt-4", "api_key": "key", "base_url": "..."}, ... "embedder": {"model": "azure/text-embedding-ada-002", "api_key": "key", "base_url": "..."} ... })
528 @classmethod 529 def from_toml(cls, toml_path: str, section: str = "memory") -> "SemanticMemoryConfig": 530 """TOMLファイルからSemanticMemoryConfigを作成 531 532 Args: 533 toml_path: TOMLファイルパス 534 section: セクション名(ドット区切りでネスト対応) 535 536 Returns: 537 SemanticMemoryConfig instance 538 539 Example: 540 >>> # config.toml の [memory] セクションを読み込み 541 >>> config = SemanticMemoryConfig.from_toml("config.toml") 542 543 >>> # 以下のTOML構造を期待: 544 >>> # [memory.llm] 545 >>> # model = "azure/gpt-4" 546 >>> # api_key = "..." 547 >>> # 548 >>> # [memory.embedder] 549 >>> # model = "azure/text-embedding-ada-002" 550 >>> # api_key = "..." 551 """ 552 if not TOML_AVAILABLE: 553 raise ImportError("toml library is required: pip install toml") 554 555 with open(toml_path, "r") as f: 556 toml_data = toml.load(f) 557 558 # ネストされたセクションをドット区切りで辿る 559 data = toml_data 560 for key in section.split("."): 561 data = data[key] 562 563 return cls.from_dict(data)
TOMLファイルからSemanticMemoryConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: SemanticMemoryConfig instance
Example:
config.toml の [memory] セクションを読み込み
config = SemanticMemoryConfig.from_toml("config.toml")
>>> # 以下のTOML構造を期待: >>> # [memory.llm] >>> # model = "azure/gpt-4" >>> # api_key = "..." >>> # >>> # [memory.embedder] >>> # model = "azure/text-embedding-ada-002" >>> # api_key = "..."
244class EpisodicMemoryClient: 245 """ 246 汎用エピソディックメモリクライアント(Graphiti/FalkorDBベース) 247 248 Features: 249 - マルチプロバイダーLLM/Embedder/Reranker対応 250 - FalkorDBグラフストレージ 251 - エピソード追加・検索 252 253 Note: 254 - 組織階層(personal/role/org/global)は提供しない 255 - PIIマスキングは提供しない(利用側で実装) 256 - Community機能は提供しない(利用側で必要に応じて実装) 257 258 Warning: 259 - graphiti_coreの制約により、EMBEDDING_DIMはプロセス内でグローバルに設定される 260 - 異なるembedding_modelを使用する複数クライアントを同一プロセスで使用すると、 261 最後に作成されたクライアントのEMBEDDING_DIMが適用される 262 - 異なる次元数のモデルを使用する場合は、別プロセスで実行することを推奨 263 """ 264 265 # Embedding dimensions by model 266 EMBEDDING_DIMS = { 267 # OpenAI models 268 "text-embedding-3-small": 1536, 269 "text-embedding-3-large": 3072, 270 "text-embedding-ada-002": 1536, 271 # Azure variants 272 "text-embedding-3-small-1": 1536, 273 "ada": 1536, 274 # Voyage AI models 275 "voyage-3": 1024, 276 "voyage-3-lite": 512, 277 "voyage-code-3": 1024, 278 "voyage-finance-2": 1024, 279 "voyage-law-2": 1024, 280 "voyage-large-2": 1536, 281 "voyage-2": 1024, 282 # Gemini models 283 "text-embedding-004": 768, 284 "embedding-001": 768, 285 } 286 287 def __init__(self, config: EpisodicMemoryConfig): 288 """ 289 Initialize Episodic Memory Client 290 291 Args: 292 config: EpisodicMemoryConfig instance 293 294 Raises: 295 EpisodicMemoryConfigError: 設定が不正な場合 296 """ 297 if not config: 298 raise EpisodicMemoryConfigError("EpisodicMemoryConfig is required") 299 300 if not config.llm_config or not config.embedder_config: 301 raise EpisodicMemoryConfigError( 302 "Both llm_config and embedder_config are required" 303 ) 304 305 self._config = config 306 self._graphiti = None 307 self._driver = None 308 self._enabled = False 309 self._embedding_dim: Optional[int] = None 310 # Track clients for proper cleanup 311 self._llm_client = None 312 self._embedder = None 313 self._reranker = None 314 self._openai_client = None # For Azure/OpenAI async client 315 316 try: 317 from graphiti_core import Graphiti 318 from graphiti_core.driver.falkordb_driver import FalkorDriver 319 320 # Initialize FalkorDB driver 321 self._driver = FalkorDriver( 322 host=config.host, 323 port=config.port, 324 database=config.database, 325 ) 326 logger.info( 327 f"FalkorDB driver initialized - " 328 f"host={config.host}, port={config.port}, database={config.database}" 329 ) 330 331 # Initialize LLM client 332 self._llm_client, self._openai_client = self._create_llm_client(config.llm_config) 333 334 # Initialize embedder 335 self._embedder = self._create_embedder(config.embedder_config) 336 337 # Initialize reranker 338 self._reranker = self._create_reranker(config.llm_config, self._openai_client) 339 340 # Initialize Graphiti 341 self._graphiti = Graphiti( 342 graph_driver=self._driver, 343 llm_client=self._llm_client, 344 embedder=self._embedder, 345 cross_encoder=self._reranker, 346 ) 347 348 self._enabled = True 349 logger.info( 350 f"EpisodicMemoryClient initialized - " 351 f"llm={config.llm_config.model}, " 352 f"embedder={config.embedder_config.model}" 353 ) 354 355 except ImportError as e: 356 logger.error(f"graphiti-core library not installed: {e}") 357 raise EpisodicMemoryConfigError( 358 "graphiti-core library is required: pip install graphiti-core" 359 ) from e 360 except Exception as e: 361 logger.error(f"Failed to initialize EpisodicMemoryClient: {e}") 362 raise EpisodicMemoryConfigError( 363 f"Failed to initialize Graphiti: {e}" 364 ) from e 365 366 @property 367 def enabled(self) -> bool: 368 """クライアントが有効かどうか""" 369 return self._enabled 370 371 @property 372 def embedding_dim(self) -> Optional[int]: 373 """現在のembedding次元数""" 374 return self._embedding_dim 375 376 @property 377 def graphiti(self): 378 """内部Graphitiインスタンスへのアクセス(高度な操作用)""" 379 return self._graphiti 380 381 def _parse_model_string(self, model: str) -> tuple[str, str]: 382 """モデル文字列をプロバイダーとモデル名に分離""" 383 if "/" in model: 384 provider = model.split("/")[0] 385 model_name = model.split("/", 1)[1] 386 else: 387 provider = "openai" 388 model_name = model 389 return provider.lower(), model_name 390 391 def _get_api_key(self, config: LLMProviderConfig) -> str: 392 """APIキーを取得(SecretStr対応)""" 393 api_key = config.api_key 394 if hasattr(api_key, 'get_secret_value'): 395 return api_key.get_secret_value() 396 return api_key 397 398 def _get_embedding_dim(self, model_name: str) -> int: 399 """モデル名から埋め込み次元数を取得""" 400 # 完全一致 401 if model_name in self.EMBEDDING_DIMS: 402 return self.EMBEDDING_DIMS[model_name] 403 404 # 部分一致 405 model_lower = model_name.lower() 406 for key, dim in self.EMBEDDING_DIMS.items(): 407 if key in model_lower or model_lower in key: 408 return dim 409 410 # デフォルト 411 logger.warning( 412 f"Unknown embedding model '{model_name}', using default dimension 1536" 413 ) 414 return 1536 415 416 def _create_llm_client(self, llm_config: LLMProviderConfig): 417 """LLMクライアントを作成(マルチプロバイダー対応)""" 418 from graphiti_core.llm_client import LLMConfig, OpenAIClient 419 420 provider, model_name = self._parse_model_string(llm_config.model) 421 api_key = self._get_api_key(llm_config) 422 423 graphiti_config = LLMConfig( 424 model=model_name, 425 small_model=model_name, 426 ) 427 428 if provider == "azure": 429 from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient 430 from openai import AsyncAzureOpenAI 431 432 azure_client = AsyncAzureOpenAI( 433 azure_endpoint=llm_config.base_url, 434 api_key=api_key, 435 api_version=llm_config.api_version or "2025-03-01-preview" 436 ) 437 return AzureOpenAILLMClient( 438 azure_client=azure_client, 439 config=graphiti_config 440 ), azure_client 441 442 elif provider == "anthropic": 443 from graphiti_core.llm_client.anthropic_client import AnthropicClient 444 graphiti_config.api_key = api_key 445 client = AnthropicClient(config=graphiti_config) 446 return client, client.client 447 448 elif provider == "openai": 449 graphiti_config.api_key = api_key 450 client = OpenAIClient(config=graphiti_config) 451 return client, client.client 452 453 elif provider == "gemini": 454 from graphiti_core.llm_client.gemini_client import GeminiClient 455 graphiti_config.api_key = api_key 456 client = GeminiClient(config=graphiti_config) 457 return client, client.client 458 459 elif provider == "groq": 460 from graphiti_core.llm_client.groq_client import GroqClient 461 graphiti_config.api_key = api_key 462 client = GroqClient(config=graphiti_config) 463 return client, client.client 464 465 else: 466 raise EpisodicMemoryConfigError(f"Unsupported LLM provider: {provider}") 467 468 def _create_embedder(self, embedder_config: LLMProviderConfig): 469 """Embedderを作成(マルチプロバイダー対応) 470 471 Note: 472 graphiti_coreの制約により、EMBEDDING_DIMはプロセスレベルで設定される。 473 同一プロセスで異なるembedding次元を使用する複数クライアントは非推奨。 474 """ 475 from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig 476 477 provider, model_name = self._parse_model_string(embedder_config.model) 478 api_key = self._get_api_key(embedder_config) 479 480 # Get and store embedding dimension 481 self._embedding_dim = self._get_embedding_dim(model_name) 482 483 # Set embedding dimension environment variable (graphiti_core requirement) 484 # WARNING: This is a global setting that affects all clients in the process 485 os.environ['EMBEDDING_DIM'] = str(self._embedding_dim) 486 # Update graphiti_core constant 487 import graphiti_core.embedder.client as embedder_client 488 embedder_client.EMBEDDING_DIM = self._embedding_dim 489 logger.debug(f"Set EMBEDDING_DIM={self._embedding_dim} for model: {model_name}") 490 491 if provider == "azure": 492 from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient 493 from openai import AsyncAzureOpenAI 494 495 azure_client = AsyncAzureOpenAI( 496 azure_endpoint=embedder_config.base_url, 497 api_key=api_key, 498 api_version=embedder_config.api_version or "2023-05-15" 499 ) 500 return AzureOpenAIEmbedderClient( 501 azure_client=azure_client, 502 model=model_name 503 ) 504 505 elif provider == "openai": 506 config = OpenAIEmbedderConfig( 507 embedding_model=model_name, 508 api_key=api_key, 509 ) 510 return OpenAIEmbedder(config=config) 511 512 elif provider == "gemini": 513 from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig 514 config = GeminiEmbedderConfig( 515 embedding_model=model_name, 516 api_key=api_key, 517 ) 518 return GeminiEmbedder(config=config) 519 520 elif provider == "voyage": 521 from graphiti_core.embedder.voyage import VoyageEmbedder, VoyageEmbedderConfig 522 config = VoyageEmbedderConfig( 523 embedding_model=model_name, 524 api_key=api_key, 525 ) 526 return VoyageEmbedder(config=config) 527 528 else: 529 # Default to OpenAI format 530 config = OpenAIEmbedderConfig( 531 embedding_model=model_name, 532 api_key=api_key, 533 ) 534 return OpenAIEmbedder(config=config) 535 536 def _create_reranker(self, llm_config: LLMProviderConfig, openai_client): 537 """Rerankerを作成(マルチプロバイダー対応)""" 538 from graphiti_core.llm_client import LLMConfig 539 540 provider, model_name = self._parse_model_string(llm_config.model) 541 api_key = self._get_api_key(llm_config) 542 543 if provider in ("azure", "openai"): 544 from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient 545 reranker_config = LLMConfig(model=model_name) 546 return OpenAIRerankerClient(config=reranker_config, client=openai_client) 547 548 elif provider == "gemini": 549 from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient 550 reranker_config = LLMConfig( 551 api_key=api_key, 552 model="gemini-2.5-flash-lite-preview-06-17", 553 ) 554 return GeminiRerankerClient(config=reranker_config) 555 556 else: 557 # Anthropic, Groq等はネイティブrerankerがない 558 logger.warning( 559 f"Provider '{provider}' does not have native reranker support. " 560 "Reranking disabled." 561 ) 562 return None 563 564 # Maximum recommended episode body size (characters) 565 MAX_EPISODE_BODY_SIZE = 100000 566 567 async def add_episode( 568 self, 569 group_id: str, 570 content: str, 571 name: Optional[str] = None, 572 user_id: Optional[str] = None, 573 ) -> bool: 574 """ 575 エピソードを追加 576 577 テキストをエピソードとして追加します。 578 579 Args: 580 group_id: グループID(FalkorDB互換形式に自動正規化) 581 content: エピソード内容(テキスト) 582 name: エピソード名(省略時は自動生成) 583 user_id: ユーザーID(ソース記述に使用) 584 585 Returns: 586 成功した場合True 587 588 Raises: 589 EpisodicMemoryError: 追加に失敗した場合 590 591 Example: 592 >>> await client.add_episode( 593 ... group_id="user-123", 594 ... content="ユーザーはPythonに興味があります", 595 ... name="user_interest", 596 ... ) 597 598 Note: 599 episode_bodyが100,000文字を超える場合は警告を出力します。 600 大きなエピソードはパフォーマンスに影響する可能性があります。 601 """ 602 if not self._enabled: 603 return False 604 605 if not content: 606 raise EpisodicMemoryError("'content' is required") 607 608 try: 609 from graphiti_core.nodes import EpisodeType 610 611 # Normalize group_id 612 normalized_group_id = normalize_group_id(group_id) 613 614 # Warn if episode body is too large 615 body_size = len(content) 616 if body_size > self.MAX_EPISODE_BODY_SIZE: 617 logger.warning( 618 f"Episode body is large ({body_size} chars, recommended max: " 619 f"{self.MAX_EPISODE_BODY_SIZE}). Consider splitting into smaller episodes." 620 ) 621 622 # エピソード名を決定 623 actual_name = name 624 if not actual_name: 625 actual_name = content[:50] + "..." if len(content) > 50 else content 626 627 # Add episode to Graphiti 628 # NOTE: update_communities=False to avoid graphiti-core bugs 629 await self._graphiti.add_episode( 630 group_id=normalized_group_id, 631 name=actual_name, 632 episode_body=content, 633 source=EpisodeType.message, 634 source_description=f"sdk-client-{user_id}" if user_id else "sdk-client", 635 reference_time=datetime.now(timezone.utc), 636 update_communities=False, 637 ) 638 639 logger.debug( 640 f"Episode added - group_id={normalized_group_id}" 641 ) 642 return True 643 644 except Exception as e: 645 logger.error(f"Failed to add episode: {e}") 646 raise EpisodicMemoryError(f"Failed to add episode: {e}") from e 647 648 async def search( 649 self, 650 query: str, 651 group_ids: Optional[List[str]] = None, 652 max_facts: int = 10, 653 ) -> Dict[str, Any]: 654 """ 655 エピソード記憶を検索 656 657 Args: 658 query: 検索クエリ 659 group_ids: 検索対象グループIDリスト(省略時は全グループ) 660 max_facts: 結果数上限(デフォルト: 10) 661 662 Returns: 663 Dict[str, Any]: 検索結果 664 - success: bool 665 - data: { 666 "query": str, 667 "results": [{"fact": str, "created_at": str, "uuid": str}], 668 "total_found": int 669 } 670 - error: str, error_code: str (失敗時) 671 672 Example: 673 >>> result = await client.search( 674 ... query="Pythonに関する情報", 675 ... group_ids=["user-123"], 676 ... max_facts=5, 677 ... ) 678 >>> if result["success"]: 679 ... for item in result["data"]["results"]: 680 ... print(item["fact"]) 681 682 Raises: 683 EpisodicMemoryError: 検索に失敗した場合 684 """ 685 if not self._enabled: 686 return { 687 "success": False, 688 "error": "Episodic memory is not enabled", 689 "error_code": "DISABLED", 690 } 691 692 try: 693 # Normalize group_ids 694 normalized_group_ids = None 695 if group_ids: 696 normalized_group_ids = [normalize_group_id(gid) for gid in group_ids] 697 698 # Search 699 edges = await self._graphiti.search( 700 query=query, 701 group_ids=normalized_group_ids, 702 num_results=max_facts, 703 ) 704 705 # Convert to fact dicts 706 facts = [] 707 for edge in edges: 708 fact = { 709 "fact": edge.fact if hasattr(edge, 'fact') else str(edge), 710 "created_at": edge.created_at.isoformat() if hasattr(edge, 'created_at') else None, 711 "uuid": edge.uuid if hasattr(edge, 'uuid') else None, 712 } 713 facts.append(fact) 714 715 logger.debug(f"Search returned {len(facts)} facts") 716 return { 717 "success": True, 718 "data": { 719 "query": query, 720 "results": facts, 721 "total_found": len(facts), 722 }, 723 } 724 725 except Exception as e: 726 logger.error(f"Failed to search: {e}") 727 return { 728 "success": False, 729 "error": f"Search error: {str(e)}", 730 "error_code": "EPISODIC_SEARCH_ERROR", 731 } 732 733 async def get_recent_episodes( 734 self, 735 group_id: str, 736 last_n: int = 5, 737 ) -> List[Any]: 738 """ 739 最近のエピソードを取得 740 741 Args: 742 group_id: グループID 743 last_n: 取得するエピソード数 744 745 Returns: 746 エピソードリスト 747 """ 748 if not self._enabled: 749 return [] 750 751 try: 752 normalized_group_id = normalize_group_id(group_id) 753 754 episodes = await self._graphiti.retrieve_episodes( 755 group_ids=[normalized_group_id], 756 last_n=last_n, 757 reference_time=datetime.now(timezone.utc), 758 ) 759 760 logger.debug( 761 f"Retrieved {len(episodes)} recent episodes for group_id={group_id}" 762 ) 763 return episodes 764 765 except Exception as e: 766 logger.error(f"Failed to get recent episodes: {e}") 767 raise EpisodicMemoryError( 768 f"Failed to get recent episodes: {e}" 769 ) from e 770 771 def format_facts_for_context( 772 self, 773 facts: List[Dict[str, Any]], 774 max_facts: int = 5, 775 title: str = "### Episodic Memory", 776 ) -> str: 777 """ 778 ファクトをLLMコンテキスト用にフォーマット 779 780 Args: 781 facts: ファクトリスト 782 max_facts: 最大ファクト数 783 title: セクションタイトル 784 785 Returns: 786 フォーマットされた文字列 787 """ 788 if not facts: 789 return "" 790 791 lines = [title] 792 793 for i, fact in enumerate(facts[:max_facts], 1): 794 fact_text = fact.get("fact", "") 795 created_at = fact.get("created_at") 796 797 if created_at: 798 try: 799 dt = datetime.fromisoformat(created_at.replace("Z", "+00:00")) 800 time_str = dt.strftime("%Y-%m-%d %H:%M") 801 except: 802 time_str = created_at 803 lines.append(f"{i}. [{time_str}] {fact_text}") 804 else: 805 lines.append(f"{i}. {fact_text}") 806 807 return "\n".join(lines) 808 809 async def close(self) -> None: 810 """コネクションとクライアントリソースをクローズ 811 812 Note: 813 FalkorDBドライバー、LLMクライアント、Embedder、Rerankerを 814 すべてクローズします。 815 """ 816 errors = [] 817 818 # Close FalkorDB driver 819 try: 820 if self._driver and hasattr(self._driver, 'close'): 821 await self._driver.close() 822 self._driver = None 823 except Exception as e: 824 errors.append(f"driver: {e}") 825 826 # Close OpenAI/Azure async client 827 try: 828 if self._openai_client and hasattr(self._openai_client, 'close'): 829 await self._openai_client.close() 830 self._openai_client = None 831 except Exception as e: 832 errors.append(f"openai_client: {e}") 833 834 # Note: LLM client, embedder, reranker typically don't have close methods 835 # but we clear references to allow garbage collection 836 self._llm_client = None 837 self._embedder = None 838 self._reranker = None 839 self._graphiti = None 840 self._enabled = False 841 842 if errors: 843 logger.warning(f"Errors during close: {', '.join(errors)}") 844 else: 845 logger.debug("EpisodicMemoryClient closed successfully")
汎用エピソディックメモリクライアント(Graphiti/FalkorDBベース)
Features:
- マルチプロバイダーLLM/Embedder/Reranker対応
- FalkorDBグラフストレージ
- エピソード追加・検索
Note:
- 組織階層(personal/role/org/global)は提供しない
- PIIマスキングは提供しない(利用側で実装)
- Community機能は提供しない(利用側で必要に応じて実装)
Warning:
- graphiti_coreの制約により、EMBEDDING_DIMはプロセス内でグローバルに設定される
- 異なるembedding_modelを使用する複数クライアントを同一プロセスで使用すると、 最後に作成されたクライアントのEMBEDDING_DIMが適用される
- 異なる次元数のモデルを使用する場合は、別プロセスで実行することを推奨
287 def __init__(self, config: EpisodicMemoryConfig): 288 """ 289 Initialize Episodic Memory Client 290 291 Args: 292 config: EpisodicMemoryConfig instance 293 294 Raises: 295 EpisodicMemoryConfigError: 設定が不正な場合 296 """ 297 if not config: 298 raise EpisodicMemoryConfigError("EpisodicMemoryConfig is required") 299 300 if not config.llm_config or not config.embedder_config: 301 raise EpisodicMemoryConfigError( 302 "Both llm_config and embedder_config are required" 303 ) 304 305 self._config = config 306 self._graphiti = None 307 self._driver = None 308 self._enabled = False 309 self._embedding_dim: Optional[int] = None 310 # Track clients for proper cleanup 311 self._llm_client = None 312 self._embedder = None 313 self._reranker = None 314 self._openai_client = None # For Azure/OpenAI async client 315 316 try: 317 from graphiti_core import Graphiti 318 from graphiti_core.driver.falkordb_driver import FalkorDriver 319 320 # Initialize FalkorDB driver 321 self._driver = FalkorDriver( 322 host=config.host, 323 port=config.port, 324 database=config.database, 325 ) 326 logger.info( 327 f"FalkorDB driver initialized - " 328 f"host={config.host}, port={config.port}, database={config.database}" 329 ) 330 331 # Initialize LLM client 332 self._llm_client, self._openai_client = self._create_llm_client(config.llm_config) 333 334 # Initialize embedder 335 self._embedder = self._create_embedder(config.embedder_config) 336 337 # Initialize reranker 338 self._reranker = self._create_reranker(config.llm_config, self._openai_client) 339 340 # Initialize Graphiti 341 self._graphiti = Graphiti( 342 graph_driver=self._driver, 343 llm_client=self._llm_client, 344 embedder=self._embedder, 345 cross_encoder=self._reranker, 346 ) 347 348 self._enabled = True 349 logger.info( 350 f"EpisodicMemoryClient initialized - " 351 f"llm={config.llm_config.model}, " 352 f"embedder={config.embedder_config.model}" 353 ) 354 355 except ImportError as e: 356 logger.error(f"graphiti-core library not installed: {e}") 357 raise EpisodicMemoryConfigError( 358 "graphiti-core library is required: pip install graphiti-core" 359 ) from e 360 except Exception as e: 361 logger.error(f"Failed to initialize EpisodicMemoryClient: {e}") 362 raise EpisodicMemoryConfigError( 363 f"Failed to initialize Graphiti: {e}" 364 ) from e
Initialize Episodic Memory Client
Args: config: EpisodicMemoryConfig instance
Raises: EpisodicMemoryConfigError: 設定が不正な場合
371 @property 372 def embedding_dim(self) -> Optional[int]: 373 """現在のembedding次元数""" 374 return self._embedding_dim
現在のembedding次元数
376 @property 377 def graphiti(self): 378 """内部Graphitiインスタンスへのアクセス(高度な操作用)""" 379 return self._graphiti
内部Graphitiインスタンスへのアクセス(高度な操作用)
567 async def add_episode( 568 self, 569 group_id: str, 570 content: str, 571 name: Optional[str] = None, 572 user_id: Optional[str] = None, 573 ) -> bool: 574 """ 575 エピソードを追加 576 577 テキストをエピソードとして追加します。 578 579 Args: 580 group_id: グループID(FalkorDB互換形式に自動正規化) 581 content: エピソード内容(テキスト) 582 name: エピソード名(省略時は自動生成) 583 user_id: ユーザーID(ソース記述に使用) 584 585 Returns: 586 成功した場合True 587 588 Raises: 589 EpisodicMemoryError: 追加に失敗した場合 590 591 Example: 592 >>> await client.add_episode( 593 ... group_id="user-123", 594 ... content="ユーザーはPythonに興味があります", 595 ... name="user_interest", 596 ... ) 597 598 Note: 599 episode_bodyが100,000文字を超える場合は警告を出力します。 600 大きなエピソードはパフォーマンスに影響する可能性があります。 601 """ 602 if not self._enabled: 603 return False 604 605 if not content: 606 raise EpisodicMemoryError("'content' is required") 607 608 try: 609 from graphiti_core.nodes import EpisodeType 610 611 # Normalize group_id 612 normalized_group_id = normalize_group_id(group_id) 613 614 # Warn if episode body is too large 615 body_size = len(content) 616 if body_size > self.MAX_EPISODE_BODY_SIZE: 617 logger.warning( 618 f"Episode body is large ({body_size} chars, recommended max: " 619 f"{self.MAX_EPISODE_BODY_SIZE}). Consider splitting into smaller episodes." 620 ) 621 622 # エピソード名を決定 623 actual_name = name 624 if not actual_name: 625 actual_name = content[:50] + "..." if len(content) > 50 else content 626 627 # Add episode to Graphiti 628 # NOTE: update_communities=False to avoid graphiti-core bugs 629 await self._graphiti.add_episode( 630 group_id=normalized_group_id, 631 name=actual_name, 632 episode_body=content, 633 source=EpisodeType.message, 634 source_description=f"sdk-client-{user_id}" if user_id else "sdk-client", 635 reference_time=datetime.now(timezone.utc), 636 update_communities=False, 637 ) 638 639 logger.debug( 640 f"Episode added - group_id={normalized_group_id}" 641 ) 642 return True 643 644 except Exception as e: 645 logger.error(f"Failed to add episode: {e}") 646 raise EpisodicMemoryError(f"Failed to add episode: {e}") from e
エピソードを追加
テキストをエピソードとして追加します。
Args: group_id: グループID(FalkorDB互換形式に自動正規化) content: エピソード内容(テキスト) name: エピソード名(省略時は自動生成) user_id: ユーザーID(ソース記述に使用)
Returns: 成功した場合True
Raises: EpisodicMemoryError: 追加に失敗した場合
Example:
await client.add_episode( ... group_id="user-123", ... content="ユーザーはPythonに興味があります", ... name="user_interest", ... )
Note: episode_bodyが100,000文字を超える場合は警告を出力します。 大きなエピソードはパフォーマンスに影響する可能性があります。
648 async def search( 649 self, 650 query: str, 651 group_ids: Optional[List[str]] = None, 652 max_facts: int = 10, 653 ) -> Dict[str, Any]: 654 """ 655 エピソード記憶を検索 656 657 Args: 658 query: 検索クエリ 659 group_ids: 検索対象グループIDリスト(省略時は全グループ) 660 max_facts: 結果数上限(デフォルト: 10) 661 662 Returns: 663 Dict[str, Any]: 検索結果 664 - success: bool 665 - data: { 666 "query": str, 667 "results": [{"fact": str, "created_at": str, "uuid": str}], 668 "total_found": int 669 } 670 - error: str, error_code: str (失敗時) 671 672 Example: 673 >>> result = await client.search( 674 ... query="Pythonに関する情報", 675 ... group_ids=["user-123"], 676 ... max_facts=5, 677 ... ) 678 >>> if result["success"]: 679 ... for item in result["data"]["results"]: 680 ... print(item["fact"]) 681 682 Raises: 683 EpisodicMemoryError: 検索に失敗した場合 684 """ 685 if not self._enabled: 686 return { 687 "success": False, 688 "error": "Episodic memory is not enabled", 689 "error_code": "DISABLED", 690 } 691 692 try: 693 # Normalize group_ids 694 normalized_group_ids = None 695 if group_ids: 696 normalized_group_ids = [normalize_group_id(gid) for gid in group_ids] 697 698 # Search 699 edges = await self._graphiti.search( 700 query=query, 701 group_ids=normalized_group_ids, 702 num_results=max_facts, 703 ) 704 705 # Convert to fact dicts 706 facts = [] 707 for edge in edges: 708 fact = { 709 "fact": edge.fact if hasattr(edge, 'fact') else str(edge), 710 "created_at": edge.created_at.isoformat() if hasattr(edge, 'created_at') else None, 711 "uuid": edge.uuid if hasattr(edge, 'uuid') else None, 712 } 713 facts.append(fact) 714 715 logger.debug(f"Search returned {len(facts)} facts") 716 return { 717 "success": True, 718 "data": { 719 "query": query, 720 "results": facts, 721 "total_found": len(facts), 722 }, 723 } 724 725 except Exception as e: 726 logger.error(f"Failed to search: {e}") 727 return { 728 "success": False, 729 "error": f"Search error: {str(e)}", 730 "error_code": "EPISODIC_SEARCH_ERROR", 731 }
エピソード記憶を検索
Args: query: 検索クエリ group_ids: 検索対象グループIDリスト(省略時は全グループ) max_facts: 結果数上限(デフォルト: 10)
Returns: Dict[str, Any]: 検索結果 - success: bool - data: { "query": str, "results": [{"fact": str, "created_at": str, "uuid": str}], "total_found": int } - error: str, error_code: str (失敗時)
Example:
result = await client.search( ... query="Pythonに関する情報", ... group_ids=["user-123"], ... max_facts=5, ... ) if result["success"]: ... for item in result["data"]["results"]: ... print(item["fact"])
Raises: EpisodicMemoryError: 検索に失敗した場合
733 async def get_recent_episodes( 734 self, 735 group_id: str, 736 last_n: int = 5, 737 ) -> List[Any]: 738 """ 739 最近のエピソードを取得 740 741 Args: 742 group_id: グループID 743 last_n: 取得するエピソード数 744 745 Returns: 746 エピソードリスト 747 """ 748 if not self._enabled: 749 return [] 750 751 try: 752 normalized_group_id = normalize_group_id(group_id) 753 754 episodes = await self._graphiti.retrieve_episodes( 755 group_ids=[normalized_group_id], 756 last_n=last_n, 757 reference_time=datetime.now(timezone.utc), 758 ) 759 760 logger.debug( 761 f"Retrieved {len(episodes)} recent episodes for group_id={group_id}" 762 ) 763 return episodes 764 765 except Exception as e: 766 logger.error(f"Failed to get recent episodes: {e}") 767 raise EpisodicMemoryError( 768 f"Failed to get recent episodes: {e}" 769 ) from e
最近のエピソードを取得
Args: group_id: グループID last_n: 取得するエピソード数
Returns: エピソードリスト
771 def format_facts_for_context( 772 self, 773 facts: List[Dict[str, Any]], 774 max_facts: int = 5, 775 title: str = "### Episodic Memory", 776 ) -> str: 777 """ 778 ファクトをLLMコンテキスト用にフォーマット 779 780 Args: 781 facts: ファクトリスト 782 max_facts: 最大ファクト数 783 title: セクションタイトル 784 785 Returns: 786 フォーマットされた文字列 787 """ 788 if not facts: 789 return "" 790 791 lines = [title] 792 793 for i, fact in enumerate(facts[:max_facts], 1): 794 fact_text = fact.get("fact", "") 795 created_at = fact.get("created_at") 796 797 if created_at: 798 try: 799 dt = datetime.fromisoformat(created_at.replace("Z", "+00:00")) 800 time_str = dt.strftime("%Y-%m-%d %H:%M") 801 except: 802 time_str = created_at 803 lines.append(f"{i}. [{time_str}] {fact_text}") 804 else: 805 lines.append(f"{i}. {fact_text}") 806 807 return "\n".join(lines)
ファクトをLLMコンテキスト用にフォーマット
Args: facts: ファクトリスト max_facts: 最大ファクト数 title: セクションタイトル
Returns: フォーマットされた文字列
809 async def close(self) -> None: 810 """コネクションとクライアントリソースをクローズ 811 812 Note: 813 FalkorDBドライバー、LLMクライアント、Embedder、Rerankerを 814 すべてクローズします。 815 """ 816 errors = [] 817 818 # Close FalkorDB driver 819 try: 820 if self._driver and hasattr(self._driver, 'close'): 821 await self._driver.close() 822 self._driver = None 823 except Exception as e: 824 errors.append(f"driver: {e}") 825 826 # Close OpenAI/Azure async client 827 try: 828 if self._openai_client and hasattr(self._openai_client, 'close'): 829 await self._openai_client.close() 830 self._openai_client = None 831 except Exception as e: 832 errors.append(f"openai_client: {e}") 833 834 # Note: LLM client, embedder, reranker typically don't have close methods 835 # but we clear references to allow garbage collection 836 self._llm_client = None 837 self._embedder = None 838 self._reranker = None 839 self._graphiti = None 840 self._enabled = False 841 842 if errors: 843 logger.warning(f"Errors during close: {', '.join(errors)}") 844 else: 845 logger.debug("EpisodicMemoryClient closed successfully")
コネクションとクライアントリソースをクローズ
Note: FalkorDBドライバー、LLMクライアント、Embedder、Rerankerを すべてクローズします。
153@dataclass 154class EpisodicMemoryConfig: 155 """Episodic Memory設定 156 157 FalkorDB/Graphitiベースのエピソード記憶システムの設定。 158 159 Attributes: 160 llm_config: LLMプロバイダー設定 161 embedder_config: Embedderプロバイダー設定 162 host: FalkorDBホスト(デフォルト: localhost) 163 port: FalkorDBポート(デフォルト: 6379) 164 database: データベース名(デフォルト: graphiti) 165 166 Example: 167 >>> from agenticstar_platform.memory import ( 168 ... EpisodicMemoryConfig, EpisodicMemoryClient, LLMProviderConfig 169 ... ) 170 >>> 171 >>> config = EpisodicMemoryConfig( 172 ... llm_config=LLMProviderConfig( 173 ... model="azure/gpt-4", 174 ... api_key="your-key", 175 ... base_url="https://your-resource.openai.azure.com", 176 ... ), 177 ... embedder_config=LLMProviderConfig( 178 ... model="azure/text-embedding-3-small", 179 ... api_key="your-key", 180 ... base_url="https://your-resource.openai.azure.com", 181 ... ), 182 ... host="localhost", 183 ... port=6379, 184 ... ) 185 >>> client = EpisodicMemoryClient(config) 186 """ 187 llm_config: LLMProviderConfig 188 embedder_config: LLMProviderConfig 189 # FalkorDB connection 190 host: str = "localhost" 191 port: int = 6379 192 database: str = "graphiti" 193 194 @classmethod 195 def from_dict(cls, data: Dict[str, Any]) -> "EpisodicMemoryConfig": 196 """辞書からEpisodicMemoryConfigを作成 197 198 Args: 199 data: 設定辞書 200 201 Returns: 202 EpisodicMemoryConfig instance 203 204 Example: 205 >>> config = EpisodicMemoryConfig.from_dict({ 206 ... "llm": {"model": "azure/gpt-4", "api_key": "...", "base_url": "..."}, 207 ... "embedder": {"model": "azure/text-embedding-3-small", ...}, 208 ... "host": "localhost", 209 ... "port": 6379, 210 ... }) 211 """ 212 # LLM config 213 llm_data = data.get("llm") or data.get("llm_config") or {} 214 llm_config = LLMProviderConfig( 215 model=llm_data.get("model", ""), 216 api_key=llm_data.get("api_key", ""), 217 base_url=llm_data.get("base_url"), 218 api_version=llm_data.get("api_version"), 219 ) 220 221 # Embedder config 222 embedder_data = data.get("embedder") or data.get("embedder_config") or {} 223 embedder_config = LLMProviderConfig( 224 model=embedder_data.get("model", ""), 225 api_key=embedder_data.get("api_key", ""), 226 base_url=embedder_data.get("base_url"), 227 api_version=embedder_data.get("api_version"), 228 ) 229 230 # FalkorDB connection 231 host = data.get("host", "localhost") 232 port = data.get("port", 6379) 233 database = data.get("database", "graphiti") 234 235 return cls( 236 llm_config=llm_config, 237 embedder_config=embedder_config, 238 host=host, 239 port=port, 240 database=database, 241 )
Episodic Memory設定
FalkorDB/Graphitiベースのエピソード記憶システムの設定。
Attributes: llm_config: LLMプロバイダー設定 embedder_config: Embedderプロバイダー設定 host: FalkorDBホスト(デフォルト: localhost) port: FalkorDBポート(デフォルト: 6379) database: データベース名(デフォルト: graphiti)
Example:
from agenticstar_platform.memory import ( ... EpisodicMemoryConfig, EpisodicMemoryClient, LLMProviderConfig ... )
config = EpisodicMemoryConfig( ... llm_config=LLMProviderConfig( ... model="azure/gpt-4", ... api_key="your-key", ... base_url="https://your-resource.openai.azure.com", ... ), ... embedder_config=LLMProviderConfig( ... model="azure/text-embedding-3-small", ... api_key="your-key", ... base_url="https://your-resource.openai.azure.com", ... ), ... host="localhost", ... port=6379, ... ) client = EpisodicMemoryClient(config)
194 @classmethod 195 def from_dict(cls, data: Dict[str, Any]) -> "EpisodicMemoryConfig": 196 """辞書からEpisodicMemoryConfigを作成 197 198 Args: 199 data: 設定辞書 200 201 Returns: 202 EpisodicMemoryConfig instance 203 204 Example: 205 >>> config = EpisodicMemoryConfig.from_dict({ 206 ... "llm": {"model": "azure/gpt-4", "api_key": "...", "base_url": "..."}, 207 ... "embedder": {"model": "azure/text-embedding-3-small", ...}, 208 ... "host": "localhost", 209 ... "port": 6379, 210 ... }) 211 """ 212 # LLM config 213 llm_data = data.get("llm") or data.get("llm_config") or {} 214 llm_config = LLMProviderConfig( 215 model=llm_data.get("model", ""), 216 api_key=llm_data.get("api_key", ""), 217 base_url=llm_data.get("base_url"), 218 api_version=llm_data.get("api_version"), 219 ) 220 221 # Embedder config 222 embedder_data = data.get("embedder") or data.get("embedder_config") or {} 223 embedder_config = LLMProviderConfig( 224 model=embedder_data.get("model", ""), 225 api_key=embedder_data.get("api_key", ""), 226 base_url=embedder_data.get("base_url"), 227 api_version=embedder_data.get("api_version"), 228 ) 229 230 # FalkorDB connection 231 host = data.get("host", "localhost") 232 port = data.get("port", 6379) 233 database = data.get("database", "graphiti") 234 235 return cls( 236 llm_config=llm_config, 237 embedder_config=embedder_config, 238 host=host, 239 port=port, 240 database=database, 241 )
辞書からEpisodicMemoryConfigを作成
Args: data: 設定辞書
Returns: EpisodicMemoryConfig instance
Example:
config = EpisodicMemoryConfig.from_dict({ ... "llm": {"model": "azure/gpt-4", "api_key": "...", "base_url": "..."}, ... "embedder": {"model": "azure/text-embedding-3-small", ...}, ... "host": "localhost", ... "port": 6379, ... })
104@dataclass 105class LLMProviderConfig: 106 """LLMプロバイダー設定 107 108 マルチプロバイダー対応: 109 - azure: Azure OpenAI 110 - openai: OpenAI 111 - anthropic: Anthropic Claude 112 - bedrock: AWS Bedrock 113 - gemini: Google Gemini 114 - groq: Groq 115 116 Example: 117 >>> # 辞書から作成 118 >>> config = LLMProviderConfig.from_dict({ 119 ... "model": "azure/gpt-4", 120 ... "api_key": "your-api-key", 121 ... "base_url": "https://your-resource.openai.azure.com/", 122 ... "api_version": "2024-02-15-preview" 123 ... }) 124 125 >>> # TOMLファイルから作成 126 >>> config = LLMProviderConfig.from_toml("config.toml", section="memory.llm") 127 """ 128 model: str # "azure/gpt-4" or "openai/gpt-4" format 129 api_key: str 130 base_url: Optional[str] = None # For Azure 131 api_version: Optional[str] = None # For Azure 132 # AWS Bedrock specific 133 aws_access_key_id: Optional[str] = None 134 aws_secret_access_key: Optional[str] = None 135 aws_region_name: Optional[str] = None 136 137 @classmethod 138 def from_dict(cls, data: Dict[str, Any]) -> "LLMProviderConfig": 139 """辞書からLLMProviderConfigを作成 140 141 Args: 142 data: 設定辞書 143 144 Returns: 145 LLMProviderConfig instance 146 147 Example: 148 >>> config = LLMProviderConfig.from_dict({ 149 ... "model": "azure/gpt-4", 150 ... "api_key": "your-api-key", 151 ... "base_url": "https://your-resource.openai.azure.com/", 152 ... "api_version": "2024-02-15-preview" 153 ... }) 154 """ 155 return cls( 156 model=data["model"], 157 api_key=data["api_key"], 158 base_url=data.get("base_url"), 159 api_version=data.get("api_version"), 160 aws_access_key_id=data.get("aws_access_key_id"), 161 aws_secret_access_key=data.get("aws_secret_access_key"), 162 aws_region_name=data.get("aws_region_name"), 163 ) 164 165 @classmethod 166 def from_toml(cls, toml_path: str, section: str) -> "LLMProviderConfig": 167 """TOMLファイルからLLMProviderConfigを作成 168 169 Args: 170 toml_path: TOMLファイルパス 171 section: セクション名(ドット区切りでネスト対応) 172 173 Returns: 174 LLMProviderConfig instance 175 176 Example: 177 >>> # config.toml の [memory.llm] セクションを読み込み 178 >>> config = LLMProviderConfig.from_toml("config.toml", section="memory.llm") 179 180 >>> # embedder設定の読み込み 181 >>> embedder = LLMProviderConfig.from_toml("config.toml", section="memory.embedder") 182 """ 183 if not TOML_AVAILABLE: 184 raise ImportError("toml library is required: pip install toml") 185 186 with open(toml_path, "r") as f: 187 toml_data = toml.load(f) 188 189 # ネストされたセクションをドット区切りで辿る 190 data = toml_data 191 for key in section.split("."): 192 data = data[key] 193 194 return cls.from_dict(data) 195 196 @classmethod 197 def from_env(cls, prefix: str = "LLM_") -> "LLMProviderConfig": 198 """環境変数からLLMProviderConfigを作成 199 200 環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。 201 202 Args: 203 prefix: 環境変数名のプレフィックス(デフォルト: "LLM_") 204 205 Returns: 206 LLMProviderConfig instance 207 208 Raises: 209 ValueError: 必須の環境変数(MODEL, API_KEY)が設定されていない場合 210 211 Example: 212 >>> # 環境変数から作成(LLM_MODEL, LLM_API_KEY, etc.) 213 >>> config = LLMProviderConfig.from_env() 214 215 >>> # Memory用LLM設定(MEMORY_LLM_MODEL, MEMORY_LLM_API_KEY, etc.) 216 >>> config = LLMProviderConfig.from_env(prefix="MEMORY_LLM_") 217 218 >>> # Embedder設定(EMBEDDER_MODEL, EMBEDDER_API_KEY, etc.) 219 >>> embedder = LLMProviderConfig.from_env(prefix="EMBEDDER_") 220 221 Environment Variables: 222 {prefix}MODEL: モデル名(必須)例: "azure/gpt-4", "openai/gpt-4" 223 {prefix}API_KEY: APIキー(必須) 224 {prefix}BASE_URL: ベースURL(Azure等で必要) 225 {prefix}API_VERSION: APIバージョン(Azure等で必要) 226 {prefix}AWS_ACCESS_KEY_ID: AWS Bedrock用 227 {prefix}AWS_SECRET_ACCESS_KEY: AWS Bedrock用 228 {prefix}AWS_REGION_NAME: AWS Bedrock用 229 """ 230 def get_env(key: str) -> Optional[str]: 231 return os.environ.get(f"{prefix}{key}") or None 232 233 model = get_env("MODEL") 234 api_key = get_env("API_KEY") 235 236 if not model: 237 raise ValueError(f"環境変数 {prefix}MODEL が設定されていません") 238 if not api_key: 239 raise ValueError(f"環境変数 {prefix}API_KEY が設定されていません") 240 241 return cls( 242 model=model, 243 api_key=api_key, 244 base_url=get_env("BASE_URL"), 245 api_version=get_env("API_VERSION"), 246 aws_access_key_id=get_env("AWS_ACCESS_KEY_ID"), 247 aws_secret_access_key=get_env("AWS_SECRET_ACCESS_KEY"), 248 aws_region_name=get_env("AWS_REGION_NAME"), 249 )
LLMプロバイダー設定
マルチプロバイダー対応:
- azure: Azure OpenAI
- openai: OpenAI
- anthropic: Anthropic Claude
- bedrock: AWS Bedrock
- gemini: Google Gemini
- groq: Groq
Example:
辞書から作成
config = LLMProviderConfig.from_dict({ ... "model": "azure/gpt-4", ... "api_key": "your-api-key", ... "base_url": "https://your-resource.openai.azure.com/", ... "api_version": "2024-02-15-preview" ... })
>>> # TOMLファイルから作成 >>> config = LLMProviderConfig.from_toml("config.toml", section="memory.llm")
137 @classmethod 138 def from_dict(cls, data: Dict[str, Any]) -> "LLMProviderConfig": 139 """辞書からLLMProviderConfigを作成 140 141 Args: 142 data: 設定辞書 143 144 Returns: 145 LLMProviderConfig instance 146 147 Example: 148 >>> config = LLMProviderConfig.from_dict({ 149 ... "model": "azure/gpt-4", 150 ... "api_key": "your-api-key", 151 ... "base_url": "https://your-resource.openai.azure.com/", 152 ... "api_version": "2024-02-15-preview" 153 ... }) 154 """ 155 return cls( 156 model=data["model"], 157 api_key=data["api_key"], 158 base_url=data.get("base_url"), 159 api_version=data.get("api_version"), 160 aws_access_key_id=data.get("aws_access_key_id"), 161 aws_secret_access_key=data.get("aws_secret_access_key"), 162 aws_region_name=data.get("aws_region_name"), 163 )
辞書からLLMProviderConfigを作成
Args: data: 設定辞書
Returns: LLMProviderConfig instance
Example:
config = LLMProviderConfig.from_dict({ ... "model": "azure/gpt-4", ... "api_key": "your-api-key", ... "base_url": "https://your-resource.openai.azure.com/", ... "api_version": "2024-02-15-preview" ... })
165 @classmethod 166 def from_toml(cls, toml_path: str, section: str) -> "LLMProviderConfig": 167 """TOMLファイルからLLMProviderConfigを作成 168 169 Args: 170 toml_path: TOMLファイルパス 171 section: セクション名(ドット区切りでネスト対応) 172 173 Returns: 174 LLMProviderConfig instance 175 176 Example: 177 >>> # config.toml の [memory.llm] セクションを読み込み 178 >>> config = LLMProviderConfig.from_toml("config.toml", section="memory.llm") 179 180 >>> # embedder設定の読み込み 181 >>> embedder = LLMProviderConfig.from_toml("config.toml", section="memory.embedder") 182 """ 183 if not TOML_AVAILABLE: 184 raise ImportError("toml library is required: pip install toml") 185 186 with open(toml_path, "r") as f: 187 toml_data = toml.load(f) 188 189 # ネストされたセクションをドット区切りで辿る 190 data = toml_data 191 for key in section.split("."): 192 data = data[key] 193 194 return cls.from_dict(data)
TOMLファイルからLLMProviderConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: LLMProviderConfig instance
Example:
config.toml の [memory.llm] セクションを読み込み
config = LLMProviderConfig.from_toml("config.toml", section="memory.llm")
>>> # embedder設定の読み込み >>> embedder = LLMProviderConfig.from_toml("config.toml", section="memory.embedder")
196 @classmethod 197 def from_env(cls, prefix: str = "LLM_") -> "LLMProviderConfig": 198 """環境変数からLLMProviderConfigを作成 199 200 環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。 201 202 Args: 203 prefix: 環境変数名のプレフィックス(デフォルト: "LLM_") 204 205 Returns: 206 LLMProviderConfig instance 207 208 Raises: 209 ValueError: 必須の環境変数(MODEL, API_KEY)が設定されていない場合 210 211 Example: 212 >>> # 環境変数から作成(LLM_MODEL, LLM_API_KEY, etc.) 213 >>> config = LLMProviderConfig.from_env() 214 215 >>> # Memory用LLM設定(MEMORY_LLM_MODEL, MEMORY_LLM_API_KEY, etc.) 216 >>> config = LLMProviderConfig.from_env(prefix="MEMORY_LLM_") 217 218 >>> # Embedder設定(EMBEDDER_MODEL, EMBEDDER_API_KEY, etc.) 219 >>> embedder = LLMProviderConfig.from_env(prefix="EMBEDDER_") 220 221 Environment Variables: 222 {prefix}MODEL: モデル名(必須)例: "azure/gpt-4", "openai/gpt-4" 223 {prefix}API_KEY: APIキー(必須) 224 {prefix}BASE_URL: ベースURL(Azure等で必要) 225 {prefix}API_VERSION: APIバージョン(Azure等で必要) 226 {prefix}AWS_ACCESS_KEY_ID: AWS Bedrock用 227 {prefix}AWS_SECRET_ACCESS_KEY: AWS Bedrock用 228 {prefix}AWS_REGION_NAME: AWS Bedrock用 229 """ 230 def get_env(key: str) -> Optional[str]: 231 return os.environ.get(f"{prefix}{key}") or None 232 233 model = get_env("MODEL") 234 api_key = get_env("API_KEY") 235 236 if not model: 237 raise ValueError(f"環境変数 {prefix}MODEL が設定されていません") 238 if not api_key: 239 raise ValueError(f"環境変数 {prefix}API_KEY が設定されていません") 240 241 return cls( 242 model=model, 243 api_key=api_key, 244 base_url=get_env("BASE_URL"), 245 api_version=get_env("API_VERSION"), 246 aws_access_key_id=get_env("AWS_ACCESS_KEY_ID"), 247 aws_secret_access_key=get_env("AWS_SECRET_ACCESS_KEY"), 248 aws_region_name=get_env("AWS_REGION_NAME"), 249 )
環境変数からLLMProviderConfigを作成
環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。
Args: prefix: 環境変数名のプレフィックス(デフォルト: "LLM_")
Returns: LLMProviderConfig instance
Raises: ValueError: 必須の環境変数(MODEL, API_KEY)が設定されていない場合
Example:
環境変数から作成(LLM_MODEL, LLM_API_KEY, etc.)
config = LLMProviderConfig.from_env()
>>> # Memory用LLM設定(MEMORY_LLM_MODEL, MEMORY_LLM_API_KEY, etc.) >>> config = LLMProviderConfig.from_env(prefix="MEMORY_LLM_") >>> # Embedder設定(EMBEDDER_MODEL, EMBEDDER_API_KEY, etc.) >>> embedder = LLMProviderConfig.from_env(prefix="EMBEDDER_")Environment Variables: {prefix}MODEL: モデル名(必須)例: "azure/gpt-4", "openai/gpt-4" {prefix}API_KEY: APIキー(必須) {prefix}BASE_URL: ベースURL(Azure等で必要) {prefix}API_VERSION: APIバージョン(Azure等で必要) {prefix}AWS_ACCESS_KEY_ID: AWS Bedrock用 {prefix}AWS_SECRET_ACCESS_KEY: AWS Bedrock用 {prefix}AWS_REGION_NAME: AWS Bedrock用
86def normalize_group_id(group_id: str) -> str: 87 """ 88 group_idをFalkorDB互換形式に正規化 89 90 FalkorDBの制約: 91 - 英数字、ハイフン、アンダースコアのみ使用可能 92 - 日本語等の特殊文字はハッシュ化 93 94 Args: 95 group_id: 元のgroup_id(日本語、スペース等を含む可能性あり) 96 97 Returns: 98 正規化されたgroup_id 99 100 Examples: 101 "org-ソリューションエンジニアリング本部" → "org-a1b2c3d4e5f6" 102 "role-プロダクトマネージャー" → "role-9g8h7i6j5k4l" 103 "user-12345" → "user-12345" (変更なし) 104 """ 105 # Already valid? (alphanumeric, dash, underscore only) 106 if re.match(r'^[a-zA-Z0-9_-]+$', group_id): 107 return group_id 108 109 # Extract prefix (org-/role-/user-/global-) 110 prefix_match = re.match(r'^(org-|role-|user-|global-)', group_id) 111 if prefix_match: 112 prefix = prefix_match.group(1) 113 suffix = group_id[len(prefix):] 114 else: 115 prefix = "" 116 suffix = group_id 117 118 # Hash the suffix if it contains non-ASCII or special chars 119 if not re.match(r'^[a-zA-Z0-9_-]+$', suffix): 120 # SHA256 hash (first 12 chars for readability) 121 hash_hex = hashlib.sha256(suffix.encode('utf-8')).hexdigest()[:12] 122 normalized = f"{prefix}{hash_hex}" 123 logger.debug(f"Normalized group_id: '{group_id}' → '{normalized}'") 124 return normalized 125 126 # Suffix is valid, just replace spaces with dashes 127 normalized = f"{prefix}{suffix}".replace(" ", "-").replace("_", "-") 128 return normalized
group_idをFalkorDB互換形式に正規化
FalkorDBの制約:
- 英数字、ハイフン、アンダースコアのみ使用可能
- 日本語等の特殊文字はハッシュ化
Args: group_id: 元のgroup_id(日本語、スペース等を含む可能性あり)
Returns: 正規化されたgroup_id
Examples: "org-ソリューションエンジニアリング本部" → "org-a1b2c3d4e5f6" "role-プロダクトマネージャー" → "role-9g8h7i6j5k4l" "user-12345" → "user-12345" (変更なし)
22class StorageProvider(str, Enum): 23 """ストレージプロバイダー種別""" 24 AZURE_BLOB = "azure_blob" 25 AWS_S3 = "aws_s3" 26 GCS = "gcs"
ストレージプロバイダー種別
97@dataclass 98class StorageConfig: 99 """共通ストレージ設定""" 100 provider: StorageProvider 101 bucket_name: str # Azure: container_name, S3: bucket, GCS: bucket 102 enabled: bool = True 103 # 共通オプション 104 max_file_size: int = 100 * 1024 * 1024 # 100MB 105 auto_create_bucket: bool = False 106 prefix: str = "" # オブジェクト名プレフィックス 107 custom_domain: Optional[str] = None 108 # 接続設定 109 connection_timeout: int = 30 110 read_timeout: int = 30
共通ストレージ設定
113@dataclass 114class AzureBlobConfig: 115 """Azure Blob Storage設定 116 117 Example: 118 >>> # 辞書から作成 119 >>> config = AzureBlobConfig.from_dict({ 120 ... "bucket_name": "my-container", 121 ... "connection_string": "DefaultEndpointsProtocol=https;...", 122 ... "prefix": "uploads/" 123 ... }) 124 >>> 125 >>> # クライアント初期化 126 >>> client = AzureBlobStorageClient(config) 127 128 Note: 129 connection_stringはrepr=Falseでログ出力から除外されます。 130 """ 131 bucket_name: str # container_name 132 connection_string: str = field(repr=False) # シークレット: reprから除外 133 # 共通オプション 134 enabled: bool = True 135 max_file_size: int = 100 * 1024 * 1024 # 100MB 136 auto_create_bucket: bool = False 137 prefix: str = "" 138 custom_domain: Optional[str] = None 139 connection_timeout: int = 30 140 read_timeout: int = 30 141 # Azure固有オプション 142 max_block_size: int = 4 * 1024 * 1024 # 4MB 143 max_single_put_size: int = 64 * 1024 * 1024 # 64MB 144 145 @property 146 def provider(self) -> StorageProvider: 147 return StorageProvider.AZURE_BLOB 148 149 def __str__(self) -> str: 150 return f"AzureBlobConfig(bucket_name={self.bucket_name!r}, connection_string=***)" 151 152 @classmethod 153 def from_dict(cls, data: Dict[str, Any]) -> "AzureBlobConfig": 154 """辞書からAzureBlobConfigを作成 155 156 Args: 157 data: 設定辞書 158 159 Returns: 160 AzureBlobConfig instance 161 162 Example: 163 >>> config = AzureBlobConfig.from_dict({ 164 ... "bucket_name": "my-container", 165 ... "connection_string": "DefaultEndpointsProtocol=https;..." 166 ... }) 167 """ 168 return cls( 169 bucket_name=data["bucket_name"], 170 connection_string=data["connection_string"], 171 enabled=data.get("enabled", True), 172 max_file_size=data.get("max_file_size", 100 * 1024 * 1024), 173 auto_create_bucket=data.get("auto_create_bucket", False), 174 prefix=data.get("prefix", ""), 175 custom_domain=data.get("custom_domain"), 176 connection_timeout=data.get("connection_timeout", 30), 177 read_timeout=data.get("read_timeout", 30), 178 max_block_size=data.get("max_block_size", 4 * 1024 * 1024), 179 max_single_put_size=data.get("max_single_put_size", 64 * 1024 * 1024), 180 )
Azure Blob Storage設定
Example:
辞書から作成
config = AzureBlobConfig.from_dict({ ... "bucket_name": "my-container", ... "connection_string": "DefaultEndpointsProtocol=https;...", ... "prefix": "uploads/" ... })
クライアント初期化
client = AzureBlobStorageClient(config)
Note: connection_stringはrepr=Falseでログ出力から除外されます。
152 @classmethod 153 def from_dict(cls, data: Dict[str, Any]) -> "AzureBlobConfig": 154 """辞書からAzureBlobConfigを作成 155 156 Args: 157 data: 設定辞書 158 159 Returns: 160 AzureBlobConfig instance 161 162 Example: 163 >>> config = AzureBlobConfig.from_dict({ 164 ... "bucket_name": "my-container", 165 ... "connection_string": "DefaultEndpointsProtocol=https;..." 166 ... }) 167 """ 168 return cls( 169 bucket_name=data["bucket_name"], 170 connection_string=data["connection_string"], 171 enabled=data.get("enabled", True), 172 max_file_size=data.get("max_file_size", 100 * 1024 * 1024), 173 auto_create_bucket=data.get("auto_create_bucket", False), 174 prefix=data.get("prefix", ""), 175 custom_domain=data.get("custom_domain"), 176 connection_timeout=data.get("connection_timeout", 30), 177 read_timeout=data.get("read_timeout", 30), 178 max_block_size=data.get("max_block_size", 4 * 1024 * 1024), 179 max_single_put_size=data.get("max_single_put_size", 64 * 1024 * 1024), 180 )
辞書からAzureBlobConfigを作成
Args: data: 設定辞書
Returns: AzureBlobConfig instance
Example:
config = AzureBlobConfig.from_dict({ ... "bucket_name": "my-container", ... "connection_string": "DefaultEndpointsProtocol=https;..." ... })
183@dataclass 184class S3Config: 185 """AWS S3設定 186 187 Example: 188 >>> # 辞書から作成 189 >>> config = S3Config.from_dict({ 190 ... "bucket_name": "my-bucket", 191 ... "aws_access_key_id": "AKIAXXXXXXX", 192 ... "aws_secret_access_key": "secretkey", 193 ... "region_name": "ap-northeast-1" 194 ... }) 195 >>> 196 >>> # クライアント初期化 197 >>> client = S3StorageClient(config) 198 199 Note: 200 aws_access_key_id, aws_secret_access_keyはrepr=Falseでログ出力から除外されます。 201 """ 202 bucket_name: str 203 aws_access_key_id: str = field(repr=False) # シークレット: reprから除外 204 aws_secret_access_key: str = field(repr=False) # シークレット: reprから除外 205 region_name: str = "us-east-1" 206 # 共通オプション 207 enabled: bool = True 208 max_file_size: int = 100 * 1024 * 1024 # 100MB 209 auto_create_bucket: bool = False 210 prefix: str = "" 211 custom_domain: Optional[str] = None 212 connection_timeout: int = 30 213 read_timeout: int = 30 214 # S3固有オプション 215 endpoint_url: Optional[str] = None # MinIO等のS3互換ストレージ用 216 217 @property 218 def provider(self) -> StorageProvider: 219 return StorageProvider.AWS_S3 220 221 def __str__(self) -> str: 222 return f"S3Config(bucket_name={self.bucket_name!r}, region_name={self.region_name!r}, credentials=***)" 223 224 @classmethod 225 def from_dict(cls, data: Dict[str, Any]) -> "S3Config": 226 """辞書からS3Configを作成 227 228 Args: 229 data: 設定辞書 230 231 Returns: 232 S3Config instance 233 234 Example: 235 >>> config = S3Config.from_dict({ 236 ... "bucket_name": "my-bucket", 237 ... "aws_access_key_id": "AKIAXXXXXXX", 238 ... "aws_secret_access_key": "secretkey", 239 ... "region_name": "ap-northeast-1" 240 ... }) 241 """ 242 return cls( 243 bucket_name=data["bucket_name"], 244 aws_access_key_id=data["aws_access_key_id"], 245 aws_secret_access_key=data["aws_secret_access_key"], 246 region_name=data.get("region_name", "us-east-1"), 247 enabled=data.get("enabled", True), 248 max_file_size=data.get("max_file_size", 100 * 1024 * 1024), 249 auto_create_bucket=data.get("auto_create_bucket", False), 250 prefix=data.get("prefix", ""), 251 custom_domain=data.get("custom_domain"), 252 connection_timeout=data.get("connection_timeout", 30), 253 read_timeout=data.get("read_timeout", 30), 254 endpoint_url=data.get("endpoint_url"), 255 )
AWS S3設定
Example:
辞書から作成
config = S3Config.from_dict({ ... "bucket_name": "my-bucket", ... "aws_access_key_id": "AKIAXXXXXXX", ... "aws_secret_access_key": "secretkey", ... "region_name": "ap-northeast-1" ... })
クライアント初期化
client = S3StorageClient(config)
Note: aws_access_key_id, aws_secret_access_keyはrepr=Falseでログ出力から除外されます。
224 @classmethod 225 def from_dict(cls, data: Dict[str, Any]) -> "S3Config": 226 """辞書からS3Configを作成 227 228 Args: 229 data: 設定辞書 230 231 Returns: 232 S3Config instance 233 234 Example: 235 >>> config = S3Config.from_dict({ 236 ... "bucket_name": "my-bucket", 237 ... "aws_access_key_id": "AKIAXXXXXXX", 238 ... "aws_secret_access_key": "secretkey", 239 ... "region_name": "ap-northeast-1" 240 ... }) 241 """ 242 return cls( 243 bucket_name=data["bucket_name"], 244 aws_access_key_id=data["aws_access_key_id"], 245 aws_secret_access_key=data["aws_secret_access_key"], 246 region_name=data.get("region_name", "us-east-1"), 247 enabled=data.get("enabled", True), 248 max_file_size=data.get("max_file_size", 100 * 1024 * 1024), 249 auto_create_bucket=data.get("auto_create_bucket", False), 250 prefix=data.get("prefix", ""), 251 custom_domain=data.get("custom_domain"), 252 connection_timeout=data.get("connection_timeout", 30), 253 read_timeout=data.get("read_timeout", 30), 254 endpoint_url=data.get("endpoint_url"), 255 )
辞書からS3Configを作成
Args: data: 設定辞書
Returns: S3Config instance
Example:
config = S3Config.from_dict({ ... "bucket_name": "my-bucket", ... "aws_access_key_id": "AKIAXXXXXXX", ... "aws_secret_access_key": "secretkey", ... "region_name": "ap-northeast-1" ... })
258@dataclass 259class GCSConfig: 260 """Google Cloud Storage設定 261 262 Example: 263 >>> # 辞書から作成 264 >>> config = GCSConfig.from_dict({ 265 ... "bucket_name": "my-bucket", 266 ... "project_id": "my-project", 267 ... "credentials_path": "/path/to/service-account.json" 268 ... }) 269 >>> 270 >>> # クライアント初期化 271 >>> client = GCSStorageClient(config) 272 273 Note: 274 credentials_path, credentials_jsonはrepr=Falseでログ出力から除外されます。 275 """ 276 bucket_name: str 277 project_id: str 278 # 共通オプション 279 enabled: bool = True 280 max_file_size: int = 100 * 1024 * 1024 # 100MB 281 auto_create_bucket: bool = False 282 prefix: str = "" 283 custom_domain: Optional[str] = None 284 connection_timeout: int = 30 285 read_timeout: int = 30 286 # GCS固有オプション(シークレット: reprから除外) 287 credentials_path: Optional[str] = field(default=None, repr=False) 288 credentials_json: Optional[str] = field(default=None, repr=False) 289 290 @property 291 def provider(self) -> StorageProvider: 292 return StorageProvider.GCS 293 294 def __str__(self) -> str: 295 return f"GCSConfig(bucket_name={self.bucket_name!r}, project_id={self.project_id!r}, credentials=***)" 296 297 @classmethod 298 def from_dict(cls, data: Dict[str, Any]) -> "GCSConfig": 299 """辞書からGCSConfigを作成 300 301 Args: 302 data: 設定辞書 303 304 Returns: 305 GCSConfig instance 306 307 Example: 308 >>> config = GCSConfig.from_dict({ 309 ... "bucket_name": "my-bucket", 310 ... "project_id": "my-project", 311 ... "credentials_path": "/path/to/service-account.json" 312 ... }) 313 """ 314 return cls( 315 bucket_name=data["bucket_name"], 316 project_id=data["project_id"], 317 enabled=data.get("enabled", True), 318 max_file_size=data.get("max_file_size", 100 * 1024 * 1024), 319 auto_create_bucket=data.get("auto_create_bucket", False), 320 prefix=data.get("prefix", ""), 321 custom_domain=data.get("custom_domain"), 322 connection_timeout=data.get("connection_timeout", 30), 323 read_timeout=data.get("read_timeout", 30), 324 credentials_path=data.get("credentials_path"), 325 credentials_json=data.get("credentials_json"), 326 )
Google Cloud Storage設定
Example:
辞書から作成
config = GCSConfig.from_dict({ ... "bucket_name": "my-bucket", ... "project_id": "my-project", ... "credentials_path": "/path/to/service-account.json" ... })
クライアント初期化
client = GCSStorageClient(config)
Note: credentials_path, credentials_jsonはrepr=Falseでログ出力から除外されます。
297 @classmethod 298 def from_dict(cls, data: Dict[str, Any]) -> "GCSConfig": 299 """辞書からGCSConfigを作成 300 301 Args: 302 data: 設定辞書 303 304 Returns: 305 GCSConfig instance 306 307 Example: 308 >>> config = GCSConfig.from_dict({ 309 ... "bucket_name": "my-bucket", 310 ... "project_id": "my-project", 311 ... "credentials_path": "/path/to/service-account.json" 312 ... }) 313 """ 314 return cls( 315 bucket_name=data["bucket_name"], 316 project_id=data["project_id"], 317 enabled=data.get("enabled", True), 318 max_file_size=data.get("max_file_size", 100 * 1024 * 1024), 319 auto_create_bucket=data.get("auto_create_bucket", False), 320 prefix=data.get("prefix", ""), 321 custom_domain=data.get("custom_domain"), 322 connection_timeout=data.get("connection_timeout", 30), 323 read_timeout=data.get("read_timeout", 30), 324 credentials_path=data.get("credentials_path"), 325 credentials_json=data.get("credentials_json"), 326 )
辞書からGCSConfigを作成
Args: data: 設定辞書
Returns: GCSConfig instance
Example:
config = GCSConfig.from_dict({ ... "bucket_name": "my-bucket", ... "project_id": "my-project", ... "credentials_path": "/path/to/service-account.json" ... })
39class AzureBlobStorageClient(StorageClientBase): 40 """ 41 Azure Blob Storageクライアント 42 43 Features: 44 - ファイルアップロード・ダウンロード 45 - オブジェクト一覧・削除 46 - カスタムドメイン対応 47 - 日本語ファイル名対応(URLエンコード) 48 - メタデータサニタイズ(Base64エンコード) 49 50 Example: 51 config = AzureBlobConfig( 52 bucket_name="my-container", 53 connection_string="DefaultEndpointsProtocol=https;...", 54 ) 55 client = AzureBlobStorageClient(config) 56 result = await client.upload_file("/path/to/file.txt") 57 """ 58 59 def __init__(self, config: AzureBlobConfig): 60 """ 61 Initialize Azure Blob Storage Client 62 63 Args: 64 config: AzureBlobConfig instance 65 66 Raises: 67 StorageConfigError: 設定が不正な場合 68 StorageConnectionError: 接続に失敗した場合 69 """ 70 super().__init__(config) 71 72 if not AZURE_BLOB_AVAILABLE: 73 raise StorageConfigError( 74 "azure-storage-blob package is not installed. " 75 "Install with: pip install azure-storage-blob" 76 ) 77 78 if not config.connection_string: 79 raise StorageConfigError( 80 "Azure Blob connection_string is required" 81 ) 82 83 if not config.bucket_name: 84 raise StorageConfigError( 85 "Azure Blob bucket_name (container_name) is required" 86 ) 87 88 self._azure_config: AzureBlobConfig = config 89 self._blob_service_client: Optional[BlobServiceClient] = None 90 91 try: 92 self._blob_service_client = BlobServiceClient.from_connection_string( 93 conn_str=self._get_api_key(config.connection_string), 94 max_block_size=config.max_block_size, 95 max_single_put_size=config.max_single_put_size, 96 connection_timeout=config.connection_timeout, 97 read_timeout=config.read_timeout, 98 ) 99 self._enabled = config.enabled 100 logger.info( 101 f"AzureBlobStorageClient initialized - " 102 f"container={config.bucket_name}" 103 ) 104 except Exception as e: 105 logger.error(f"Failed to initialize Azure Blob client: {e}") 106 raise StorageConnectionError( 107 f"Failed to connect to Azure Blob Storage: {e}" 108 ) from e 109 110 @property 111 def provider(self) -> StorageProvider: 112 return StorageProvider.AZURE_BLOB 113 114 @property 115 def container_name(self) -> str: 116 """コンテナ名(bucket_nameのエイリアス)""" 117 return self._azure_config.bucket_name 118 119 def _get_container_client(self): 120 """コンテナクライアントを取得""" 121 return self._blob_service_client.get_container_client(self.container_name) 122 123 def _get_blob_client(self, blob_name: str): 124 """Blobクライアントを取得""" 125 return self._blob_service_client.get_blob_client( 126 container=self.container_name, 127 blob=blob_name, 128 ) 129 130 def _encode_blob_name(self, blob_name: str) -> str: 131 """Blob名をURLエンコード(日本語対応)""" 132 try: 133 if blob_name != unquote(blob_name): 134 return blob_name 135 return quote(blob_name, safe='/') 136 except Exception: 137 return quote(blob_name, safe='/') 138 139 def _generate_public_url(self, blob_url: str) -> str: 140 """公開URLを生成(カスタムドメイン対応)""" 141 custom_domain = self._azure_config.custom_domain 142 if not custom_domain: 143 return blob_url 144 145 try: 146 parsed = urlparse(blob_url) 147 return f"https://{custom_domain}{parsed.path}" 148 except Exception as e: 149 logger.warning(f"Failed to generate custom domain URL: {e}") 150 return blob_url 151 152 async def ensure_bucket_exists(self) -> bool: 153 """コンテナの存在確認・作成""" 154 if not self._enabled: 155 return False 156 157 try: 158 container_client = self._get_container_client() 159 160 exists = await asyncio.get_event_loop().run_in_executor( 161 None, container_client.exists 162 ) 163 164 if not exists: 165 if self._azure_config.auto_create_bucket: 166 await asyncio.get_event_loop().run_in_executor( 167 None, container_client.create_container 168 ) 169 logger.info(f"Created container: {self.container_name}") 170 else: 171 logger.error( 172 f"Container {self.container_name} does not exist " 173 "and auto_create_bucket is disabled" 174 ) 175 return False 176 177 return True 178 179 except Exception as e: 180 logger.error(f"Failed to ensure container exists: {e}") 181 return False 182 183 async def upload_file( 184 self, 185 file_path: str, 186 object_name: Optional[str] = None, 187 prefix: Optional[str] = None, 188 metadata: Optional[Dict[str, Any]] = None, 189 ) -> UploadResult: 190 """ 191 ファイルをAzure Blobにアップロード 192 193 Args: 194 file_path: アップロードするファイルのローカルパス 195 object_name: Blob名(省略時はファイル名が使用される) 196 prefix: オブジェクト名プレフィックス(フォルダ構造を作成) 197 metadata: 追加メタデータ(日本語はBase64エンコードされる) 198 199 Returns: 200 UploadResult: アップロード結果 201 - success: 成功/失敗 202 - object_name: アップロード先のBlob名 203 - object_url: 公開URL 204 - file_size: ファイルサイズ(bytes) 205 - content_type: MIMEタイプ 206 207 Example: 208 >>> # 基本的なアップロード 209 >>> result = await client.upload_file("/tmp/report.pdf") 210 >>> if result.success: 211 ... print(f"Uploaded: {result.object_url}") 212 >>> 213 >>> # プレフィックス指定でフォルダ構造を作成 214 >>> result = await client.upload_file( 215 ... file_path="/tmp/image.png", 216 ... prefix="uploads/2024/01", 217 ... metadata={"author": "user-123"} 218 ... ) 219 """ 220 if not self._enabled: 221 return UploadResult( 222 success=False, 223 error="Azure Blob Storage is not enabled", 224 error_code="STORAGE_NOT_ENABLED", 225 ) 226 227 try: 228 # ファイルサイズ検証 229 self._validate_file_size(file_path) 230 231 file_path_obj = Path(file_path) 232 233 # オブジェクト名決定 234 if not object_name: 235 object_name = file_path_obj.name 236 237 # URLエンコード 238 safe_object_name = self._encode_blob_name(object_name) 239 240 # プレフィックス適用 241 config_prefix = self._azure_config.prefix 242 if prefix: 243 full_object_name = f"{config_prefix}{prefix}/{safe_object_name}" 244 elif config_prefix: 245 full_object_name = f"{config_prefix}{safe_object_name}" 246 else: 247 full_object_name = safe_object_name 248 249 # コンテナ確認 250 if not await self.ensure_bucket_exists(): 251 return UploadResult( 252 success=False, 253 error="Container is not available", 254 error_code="CONTAINER_NOT_AVAILABLE", 255 ) 256 257 # Content-Type設定 258 content_type = self._get_content_type(file_path) 259 content_settings = ContentSettings(content_type=content_type) 260 261 # メタデータサニタイズ 262 safe_metadata = self._sanitize_metadata(metadata) 263 264 # アップロード実行 265 blob_client = self._get_blob_client(full_object_name) 266 file_size = file_path_obj.stat().st_size 267 268 with open(file_path_obj, 'rb') as data: 269 await asyncio.get_event_loop().run_in_executor( 270 None, 271 lambda: blob_client.upload_blob( 272 data, 273 overwrite=True, 274 content_settings=content_settings, 275 metadata=safe_metadata if safe_metadata else None, 276 ) 277 ) 278 279 # URL生成 280 blob_url = self._generate_public_url(blob_client.url) 281 282 logger.info("Successfully uploaded file to Azure Blob") 283 284 return UploadResult( 285 success=True, 286 object_name=full_object_name, 287 object_url=blob_url, 288 file_size=file_size, 289 content_type=content_type, 290 metadata={ 291 "uploaded_at": time.time(), 292 "prefix": prefix, 293 **(safe_metadata or {}), 294 }, 295 ) 296 297 except StorageOperationError: 298 raise 299 except Exception as e: 300 logger.error(f"Failed to upload file: {e}") 301 return UploadResult( 302 success=False, 303 error=f"Upload failed: {e}", 304 error_code="UPLOAD_FAILED", 305 ) 306 307 async def download_file( 308 self, 309 object_name: str, 310 download_path: str, 311 ) -> DownloadResult: 312 """ 313 ファイルをダウンロード(ストリーミング) 314 315 大きなファイルでもメモリを圧迫しないよう、 316 チャンク単位でストリーミングダウンロードします。 317 318 Args: 319 object_name: ダウンロードするBlob名 320 download_path: ダウンロード先のローカルパス 321 322 Returns: 323 DownloadResult: ダウンロード結果 324 - success: 成功/失敗 325 - local_path: ダウンロード先パス 326 - file_size: ファイルサイズ(bytes) 327 - error: エラーメッセージ(失敗時) 328 329 Example: 330 >>> # Blobをダウンロード 331 >>> result = await client.download_file( 332 ... object_name="uploads/2024/01/report.pdf", 333 ... download_path="/tmp/downloaded_report.pdf" 334 ... ) 335 >>> if result.success: 336 ... print(f"Downloaded to: {result.local_path}") 337 ... print(f"Size: {result.file_size} bytes") 338 >>> else: 339 ... print(f"Error: {result.error}") 340 341 Note: 342 ダウンロード先ディレクトリが存在しない場合は自動作成されます。 343 """ 344 if not self._enabled: 345 return DownloadResult( 346 success=False, 347 error="Azure Blob Storage is not enabled", 348 error_code="STORAGE_NOT_ENABLED", 349 ) 350 351 try: 352 blob_client = self._get_blob_client(object_name) 353 354 # 存在確認 355 exists = await asyncio.get_event_loop().run_in_executor( 356 None, blob_client.exists 357 ) 358 if not exists: 359 return DownloadResult( 360 success=False, 361 object_name=object_name, 362 error=f"Blob not found: {object_name}", 363 error_code="BLOB_NOT_FOUND", 364 ) 365 366 # ダウンロード先ディレクトリ作成 367 download_path_obj = Path(download_path) 368 download_path_obj.parent.mkdir(parents=True, exist_ok=True) 369 370 # ストリーミングダウンロード(メモリ効率化) 371 def _streaming_download(): 372 with open(download_path_obj, 'wb') as f: 373 download_stream = blob_client.download_blob() 374 # chunks()でチャンク単位でダウンロード 375 for chunk in download_stream.chunks(): 376 f.write(chunk) 377 378 await asyncio.get_event_loop().run_in_executor( 379 None, _streaming_download 380 ) 381 382 file_size = download_path_obj.stat().st_size 383 logger.info("Successfully downloaded blob") 384 385 return DownloadResult( 386 success=True, 387 object_name=object_name, 388 local_path=str(download_path_obj), 389 file_size=file_size, 390 ) 391 392 except Exception as e: 393 logger.error(f"Failed to download blob: {e}") 394 return DownloadResult( 395 success=False, 396 object_name=object_name, 397 error=f"Download failed: {e}", 398 error_code="DOWNLOAD_FAILED", 399 ) 400 401 async def list_objects( 402 self, 403 prefix: str = "", 404 max_results: Optional[int] = None, 405 ) -> ListResult: 406 """ 407 オブジェクト一覧を取得 408 409 指定したプレフィックスに一致するBlobの一覧を取得します。 410 411 Args: 412 prefix: プレフィックスフィルタ(フォルダパスとして機能) 413 max_results: 最大取得数(省略時は全件取得) 414 415 Returns: 416 ListResult: 一覧結果 417 - success: 成功/失敗 418 - objects: ObjectInfoのリスト 419 - count: 取得件数 420 421 Example: 422 >>> # 全オブジェクト一覧 423 >>> result = await client.list_objects() 424 >>> print(f"Total: {result.count} objects") 425 >>> 426 >>> # プレフィックスでフィルタ 427 >>> result = await client.list_objects( 428 ... prefix="uploads/2024/", 429 ... max_results=100 430 ... ) 431 >>> for obj in result.objects: 432 ... print(f"{obj.name}: {obj.size} bytes") 433 """ 434 if not self._enabled: 435 return ListResult( 436 success=False, 437 error="Azure Blob Storage is not enabled", 438 error_code="STORAGE_NOT_ENABLED", 439 ) 440 441 try: 442 container_client = self._get_container_client() 443 444 # プレフィックス適用 445 full_prefix = f"{self._azure_config.prefix}{prefix}" 446 447 blob_list = await asyncio.get_event_loop().run_in_executor( 448 None, 449 lambda: list(container_client.list_blobs(name_starts_with=full_prefix)) 450 ) 451 452 objects = [] 453 for blob in blob_list: 454 if max_results and len(objects) >= max_results: 455 break 456 objects.append(ObjectInfo( 457 name=blob.name, 458 size=blob.size, 459 last_modified=blob.last_modified.isoformat() if blob.last_modified else None, 460 content_type=blob.content_settings.content_type if blob.content_settings else None, 461 )) 462 463 logger.info(f"Listed {len(objects)} blobs") 464 465 return ListResult( 466 success=True, 467 objects=objects, 468 count=len(objects), 469 prefix=prefix, 470 ) 471 472 except Exception as e: 473 logger.error(f"Failed to list blobs: {e}") 474 return ListResult( 475 success=False, 476 error=f"List failed: {e}", 477 error_code="LIST_FAILED", 478 ) 479 480 async def delete_object(self, object_name: str) -> bool: 481 """ 482 オブジェクトを削除 483 484 Args: 485 object_name: 削除するBlob名 486 487 Returns: 488 成功した場合True 489 """ 490 if not self._enabled: 491 return False 492 493 try: 494 blob_client = self._get_blob_client(object_name) 495 496 await asyncio.get_event_loop().run_in_executor( 497 None, 498 lambda: blob_client.delete_blob(delete_snapshots="include") 499 ) 500 501 logger.info("Successfully deleted blob") 502 return True 503 504 except Exception as e: 505 logger.error(f"Failed to delete blob: {e}") 506 return False 507 508 async def object_exists(self, object_name: str) -> bool: 509 """ 510 オブジェクトの存在確認 511 512 Args: 513 object_name: 確認するBlob名 514 515 Returns: 516 存在する場合True 517 """ 518 if not self._enabled: 519 return False 520 521 try: 522 blob_client = self._get_blob_client(object_name) 523 return await asyncio.get_event_loop().run_in_executor( 524 None, blob_client.exists 525 ) 526 except Exception as e: 527 logger.error(f"Failed to check blob existence: {e}") 528 return False 529 530 async def get_object_url( 531 self, 532 object_name: str, 533 expires_in: Optional[int] = None, 534 ) -> str: 535 """ 536 オブジェクトのURLを取得 537 538 Args: 539 object_name: Blob名 540 expires_in: SAS有効期限(秒)- 未実装、将来拡張用 541 542 Returns: 543 オブジェクトURL 544 """ 545 if not self._enabled: 546 return "" 547 548 blob_client = self._get_blob_client(object_name) 549 return self._generate_public_url(blob_client.url) 550 551 def _sanitize_path(self, path: str) -> str: 552 """パスをサニタイズ(パストラバーサル防止) 553 554 Args: 555 path: サニタイズするパス 556 557 Returns: 558 サニタイズされたパス(危険な文字を除去) 559 """ 560 import re 561 # パストラバーサル攻撃防止: ../ や ..\\ を除去 562 # 先頭の / も除去 563 sanitized = re.sub(r'\.\.[\\/]', '', path) 564 sanitized = re.sub(r'^[\\/]+', '', sanitized) 565 # NULL文字や制御文字を除去 566 sanitized = re.sub(r'[\x00-\x1f]', '', sanitized) 567 return sanitized 568 569 async def download_objects_by_prefix( 570 self, 571 prefix: str, 572 download_dir: str, 573 ) -> Dict[str, Any]: 574 """ 575 プレフィックスに一致するオブジェクトをすべてダウンロード 576 577 Args: 578 prefix: プレフィックス 579 download_dir: ダウンロード先ディレクトリ 580 581 Returns: 582 ダウンロード結果サマリー 583 584 Security: 585 パストラバーサル攻撃を防ぐため、オブジェクト名をサニタイズし、 586 ダウンロード先がdownload_dir配下であることを検証します。 587 """ 588 list_result = await self.list_objects(prefix) 589 if not list_result.success: 590 return { 591 "success": False, 592 "error": list_result.error, 593 "downloaded": [], 594 "failed": [], 595 } 596 597 download_dir_obj = Path(download_dir).resolve() 598 downloaded = [] 599 failed = [] 600 skipped = [] 601 602 for obj in list_result.objects: 603 # プレフィックス除去 604 relative_path = obj.name 605 if obj.name.startswith(prefix): 606 relative_path = obj.name[len(prefix):] 607 608 # パスサニタイズ(パストラバーサル防止) 609 relative_path = self._sanitize_path(relative_path) 610 611 if not relative_path: 612 skipped.append({ 613 "object_name": obj.name, 614 "reason": "Invalid path after sanitization", 615 }) 616 continue 617 618 download_path = (download_dir_obj / relative_path).resolve() 619 620 # パストラバーサル検証: download_dir配下であることを確認 621 try: 622 download_path.relative_to(download_dir_obj) 623 except ValueError: 624 logger.warning( 625 f"Path traversal attempt detected, skipping object" 626 ) 627 skipped.append({ 628 "object_name": obj.name, 629 "reason": "Path traversal detected", 630 }) 631 continue 632 633 result = await self.download_file(obj.name, str(download_path)) 634 if result.success: 635 downloaded.append({ 636 "object_name": obj.name, 637 "local_path": result.local_path, 638 "file_size": result.file_size, 639 }) 640 else: 641 failed.append({ 642 "object_name": obj.name, 643 "error": result.error, 644 }) 645 646 return { 647 "success": True, 648 "downloaded": downloaded, 649 "failed": failed, 650 "skipped": skipped, 651 "success_count": len(downloaded), 652 "failed_count": len(failed), 653 "skipped_count": len(skipped), 654 } 655 656 async def close(self) -> None: 657 """クライアントリソースをクローズ 658 659 Note: 660 例外が発生してもリソースを確実にクリーンアップします。 661 """ 662 try: 663 if self._blob_service_client: 664 # BlobServiceClientはclose()メソッドを持たないが、 665 # 将来のバージョンで追加される可能性があるため確認 666 if hasattr(self._blob_service_client, 'close'): 667 self._blob_service_client.close() 668 except Exception as e: 669 logger.warning(f"Error closing Azure Blob client: {e}") 670 finally: 671 self._blob_service_client = None 672 self._enabled = False 673 logger.debug("AzureBlobStorageClient closed")
Azure Blob Storageクライアント
Features:
- ファイルアップロード・ダウンロード
- オブジェクト一覧・削除
- カスタムドメイン対応
- 日本語ファイル名対応(URLエンコード)
- メタデータサニタイズ(Base64エンコード)
Example: config = AzureBlobConfig( bucket_name="my-container", connection_string="DefaultEndpointsProtocol=https;...", ) client = AzureBlobStorageClient(config) result = await client.upload_file("/path/to/file.txt")
59 def __init__(self, config: AzureBlobConfig): 60 """ 61 Initialize Azure Blob Storage Client 62 63 Args: 64 config: AzureBlobConfig instance 65 66 Raises: 67 StorageConfigError: 設定が不正な場合 68 StorageConnectionError: 接続に失敗した場合 69 """ 70 super().__init__(config) 71 72 if not AZURE_BLOB_AVAILABLE: 73 raise StorageConfigError( 74 "azure-storage-blob package is not installed. " 75 "Install with: pip install azure-storage-blob" 76 ) 77 78 if not config.connection_string: 79 raise StorageConfigError( 80 "Azure Blob connection_string is required" 81 ) 82 83 if not config.bucket_name: 84 raise StorageConfigError( 85 "Azure Blob bucket_name (container_name) is required" 86 ) 87 88 self._azure_config: AzureBlobConfig = config 89 self._blob_service_client: Optional[BlobServiceClient] = None 90 91 try: 92 self._blob_service_client = BlobServiceClient.from_connection_string( 93 conn_str=self._get_api_key(config.connection_string), 94 max_block_size=config.max_block_size, 95 max_single_put_size=config.max_single_put_size, 96 connection_timeout=config.connection_timeout, 97 read_timeout=config.read_timeout, 98 ) 99 self._enabled = config.enabled 100 logger.info( 101 f"AzureBlobStorageClient initialized - " 102 f"container={config.bucket_name}" 103 ) 104 except Exception as e: 105 logger.error(f"Failed to initialize Azure Blob client: {e}") 106 raise StorageConnectionError( 107 f"Failed to connect to Azure Blob Storage: {e}" 108 ) from e
Initialize Azure Blob Storage Client
Args: config: AzureBlobConfig instance
Raises: StorageConfigError: 設定が不正な場合 StorageConnectionError: 接続に失敗した場合
114 @property 115 def container_name(self) -> str: 116 """コンテナ名(bucket_nameのエイリアス)""" 117 return self._azure_config.bucket_name
コンテナ名(bucket_nameのエイリアス)
152 async def ensure_bucket_exists(self) -> bool: 153 """コンテナの存在確認・作成""" 154 if not self._enabled: 155 return False 156 157 try: 158 container_client = self._get_container_client() 159 160 exists = await asyncio.get_event_loop().run_in_executor( 161 None, container_client.exists 162 ) 163 164 if not exists: 165 if self._azure_config.auto_create_bucket: 166 await asyncio.get_event_loop().run_in_executor( 167 None, container_client.create_container 168 ) 169 logger.info(f"Created container: {self.container_name}") 170 else: 171 logger.error( 172 f"Container {self.container_name} does not exist " 173 "and auto_create_bucket is disabled" 174 ) 175 return False 176 177 return True 178 179 except Exception as e: 180 logger.error(f"Failed to ensure container exists: {e}") 181 return False
コンテナの存在確認・作成
183 async def upload_file( 184 self, 185 file_path: str, 186 object_name: Optional[str] = None, 187 prefix: Optional[str] = None, 188 metadata: Optional[Dict[str, Any]] = None, 189 ) -> UploadResult: 190 """ 191 ファイルをAzure Blobにアップロード 192 193 Args: 194 file_path: アップロードするファイルのローカルパス 195 object_name: Blob名(省略時はファイル名が使用される) 196 prefix: オブジェクト名プレフィックス(フォルダ構造を作成) 197 metadata: 追加メタデータ(日本語はBase64エンコードされる) 198 199 Returns: 200 UploadResult: アップロード結果 201 - success: 成功/失敗 202 - object_name: アップロード先のBlob名 203 - object_url: 公開URL 204 - file_size: ファイルサイズ(bytes) 205 - content_type: MIMEタイプ 206 207 Example: 208 >>> # 基本的なアップロード 209 >>> result = await client.upload_file("/tmp/report.pdf") 210 >>> if result.success: 211 ... print(f"Uploaded: {result.object_url}") 212 >>> 213 >>> # プレフィックス指定でフォルダ構造を作成 214 >>> result = await client.upload_file( 215 ... file_path="/tmp/image.png", 216 ... prefix="uploads/2024/01", 217 ... metadata={"author": "user-123"} 218 ... ) 219 """ 220 if not self._enabled: 221 return UploadResult( 222 success=False, 223 error="Azure Blob Storage is not enabled", 224 error_code="STORAGE_NOT_ENABLED", 225 ) 226 227 try: 228 # ファイルサイズ検証 229 self._validate_file_size(file_path) 230 231 file_path_obj = Path(file_path) 232 233 # オブジェクト名決定 234 if not object_name: 235 object_name = file_path_obj.name 236 237 # URLエンコード 238 safe_object_name = self._encode_blob_name(object_name) 239 240 # プレフィックス適用 241 config_prefix = self._azure_config.prefix 242 if prefix: 243 full_object_name = f"{config_prefix}{prefix}/{safe_object_name}" 244 elif config_prefix: 245 full_object_name = f"{config_prefix}{safe_object_name}" 246 else: 247 full_object_name = safe_object_name 248 249 # コンテナ確認 250 if not await self.ensure_bucket_exists(): 251 return UploadResult( 252 success=False, 253 error="Container is not available", 254 error_code="CONTAINER_NOT_AVAILABLE", 255 ) 256 257 # Content-Type設定 258 content_type = self._get_content_type(file_path) 259 content_settings = ContentSettings(content_type=content_type) 260 261 # メタデータサニタイズ 262 safe_metadata = self._sanitize_metadata(metadata) 263 264 # アップロード実行 265 blob_client = self._get_blob_client(full_object_name) 266 file_size = file_path_obj.stat().st_size 267 268 with open(file_path_obj, 'rb') as data: 269 await asyncio.get_event_loop().run_in_executor( 270 None, 271 lambda: blob_client.upload_blob( 272 data, 273 overwrite=True, 274 content_settings=content_settings, 275 metadata=safe_metadata if safe_metadata else None, 276 ) 277 ) 278 279 # URL生成 280 blob_url = self._generate_public_url(blob_client.url) 281 282 logger.info("Successfully uploaded file to Azure Blob") 283 284 return UploadResult( 285 success=True, 286 object_name=full_object_name, 287 object_url=blob_url, 288 file_size=file_size, 289 content_type=content_type, 290 metadata={ 291 "uploaded_at": time.time(), 292 "prefix": prefix, 293 **(safe_metadata or {}), 294 }, 295 ) 296 297 except StorageOperationError: 298 raise 299 except Exception as e: 300 logger.error(f"Failed to upload file: {e}") 301 return UploadResult( 302 success=False, 303 error=f"Upload failed: {e}", 304 error_code="UPLOAD_FAILED", 305 )
ファイルをAzure Blobにアップロード
Args: file_path: アップロードするファイルのローカルパス object_name: Blob名(省略時はファイル名が使用される) prefix: オブジェクト名プレフィックス(フォルダ構造を作成) metadata: 追加メタデータ(日本語はBase64エンコードされる)
Returns: UploadResult: アップロード結果 - success: 成功/失敗 - object_name: アップロード先のBlob名 - object_url: 公開URL - file_size: ファイルサイズ(bytes) - content_type: MIMEタイプ
Example:
基本的なアップロード
result = await client.upload_file("/tmp/report.pdf") if result.success: ... print(f"Uploaded: {result.object_url}")
プレフィックス指定でフォルダ構造を作成
result = await client.upload_file( ... file_path="/tmp/image.png", ... prefix="uploads/2024/01", ... metadata={"author": "user-123"} ... )
307 async def download_file( 308 self, 309 object_name: str, 310 download_path: str, 311 ) -> DownloadResult: 312 """ 313 ファイルをダウンロード(ストリーミング) 314 315 大きなファイルでもメモリを圧迫しないよう、 316 チャンク単位でストリーミングダウンロードします。 317 318 Args: 319 object_name: ダウンロードするBlob名 320 download_path: ダウンロード先のローカルパス 321 322 Returns: 323 DownloadResult: ダウンロード結果 324 - success: 成功/失敗 325 - local_path: ダウンロード先パス 326 - file_size: ファイルサイズ(bytes) 327 - error: エラーメッセージ(失敗時) 328 329 Example: 330 >>> # Blobをダウンロード 331 >>> result = await client.download_file( 332 ... object_name="uploads/2024/01/report.pdf", 333 ... download_path="/tmp/downloaded_report.pdf" 334 ... ) 335 >>> if result.success: 336 ... print(f"Downloaded to: {result.local_path}") 337 ... print(f"Size: {result.file_size} bytes") 338 >>> else: 339 ... print(f"Error: {result.error}") 340 341 Note: 342 ダウンロード先ディレクトリが存在しない場合は自動作成されます。 343 """ 344 if not self._enabled: 345 return DownloadResult( 346 success=False, 347 error="Azure Blob Storage is not enabled", 348 error_code="STORAGE_NOT_ENABLED", 349 ) 350 351 try: 352 blob_client = self._get_blob_client(object_name) 353 354 # 存在確認 355 exists = await asyncio.get_event_loop().run_in_executor( 356 None, blob_client.exists 357 ) 358 if not exists: 359 return DownloadResult( 360 success=False, 361 object_name=object_name, 362 error=f"Blob not found: {object_name}", 363 error_code="BLOB_NOT_FOUND", 364 ) 365 366 # ダウンロード先ディレクトリ作成 367 download_path_obj = Path(download_path) 368 download_path_obj.parent.mkdir(parents=True, exist_ok=True) 369 370 # ストリーミングダウンロード(メモリ効率化) 371 def _streaming_download(): 372 with open(download_path_obj, 'wb') as f: 373 download_stream = blob_client.download_blob() 374 # chunks()でチャンク単位でダウンロード 375 for chunk in download_stream.chunks(): 376 f.write(chunk) 377 378 await asyncio.get_event_loop().run_in_executor( 379 None, _streaming_download 380 ) 381 382 file_size = download_path_obj.stat().st_size 383 logger.info("Successfully downloaded blob") 384 385 return DownloadResult( 386 success=True, 387 object_name=object_name, 388 local_path=str(download_path_obj), 389 file_size=file_size, 390 ) 391 392 except Exception as e: 393 logger.error(f"Failed to download blob: {e}") 394 return DownloadResult( 395 success=False, 396 object_name=object_name, 397 error=f"Download failed: {e}", 398 error_code="DOWNLOAD_FAILED", 399 )
ファイルをダウンロード(ストリーミング)
大きなファイルでもメモリを圧迫しないよう、 チャンク単位でストリーミングダウンロードします。
Args: object_name: ダウンロードするBlob名 download_path: ダウンロード先のローカルパス
Returns: DownloadResult: ダウンロード結果 - success: 成功/失敗 - local_path: ダウンロード先パス - file_size: ファイルサイズ(bytes) - error: エラーメッセージ(失敗時)
Example:
Blobをダウンロード
result = await client.download_file( ... object_name="uploads/2024/01/report.pdf", ... download_path="/tmp/downloaded_report.pdf" ... ) if result.success: ... print(f"Downloaded to: {result.local_path}") ... print(f"Size: {result.file_size} bytes") else: ... print(f"Error: {result.error}")
Note: ダウンロード先ディレクトリが存在しない場合は自動作成されます。
401 async def list_objects( 402 self, 403 prefix: str = "", 404 max_results: Optional[int] = None, 405 ) -> ListResult: 406 """ 407 オブジェクト一覧を取得 408 409 指定したプレフィックスに一致するBlobの一覧を取得します。 410 411 Args: 412 prefix: プレフィックスフィルタ(フォルダパスとして機能) 413 max_results: 最大取得数(省略時は全件取得) 414 415 Returns: 416 ListResult: 一覧結果 417 - success: 成功/失敗 418 - objects: ObjectInfoのリスト 419 - count: 取得件数 420 421 Example: 422 >>> # 全オブジェクト一覧 423 >>> result = await client.list_objects() 424 >>> print(f"Total: {result.count} objects") 425 >>> 426 >>> # プレフィックスでフィルタ 427 >>> result = await client.list_objects( 428 ... prefix="uploads/2024/", 429 ... max_results=100 430 ... ) 431 >>> for obj in result.objects: 432 ... print(f"{obj.name}: {obj.size} bytes") 433 """ 434 if not self._enabled: 435 return ListResult( 436 success=False, 437 error="Azure Blob Storage is not enabled", 438 error_code="STORAGE_NOT_ENABLED", 439 ) 440 441 try: 442 container_client = self._get_container_client() 443 444 # プレフィックス適用 445 full_prefix = f"{self._azure_config.prefix}{prefix}" 446 447 blob_list = await asyncio.get_event_loop().run_in_executor( 448 None, 449 lambda: list(container_client.list_blobs(name_starts_with=full_prefix)) 450 ) 451 452 objects = [] 453 for blob in blob_list: 454 if max_results and len(objects) >= max_results: 455 break 456 objects.append(ObjectInfo( 457 name=blob.name, 458 size=blob.size, 459 last_modified=blob.last_modified.isoformat() if blob.last_modified else None, 460 content_type=blob.content_settings.content_type if blob.content_settings else None, 461 )) 462 463 logger.info(f"Listed {len(objects)} blobs") 464 465 return ListResult( 466 success=True, 467 objects=objects, 468 count=len(objects), 469 prefix=prefix, 470 ) 471 472 except Exception as e: 473 logger.error(f"Failed to list blobs: {e}") 474 return ListResult( 475 success=False, 476 error=f"List failed: {e}", 477 error_code="LIST_FAILED", 478 )
オブジェクト一覧を取得
指定したプレフィックスに一致するBlobの一覧を取得します。
Args: prefix: プレフィックスフィルタ(フォルダパスとして機能) max_results: 最大取得数(省略時は全件取得)
Returns: ListResult: 一覧結果 - success: 成功/失敗 - objects: ObjectInfoのリスト - count: 取得件数
Example:
全オブジェクト一覧
result = await client.list_objects() print(f"Total: {result.count} objects")
プレフィックスでフィルタ
result = await client.list_objects( ... prefix="uploads/2024/", ... max_results=100 ... ) for obj in result.objects: ... print(f"{obj.name}: {obj.size} bytes")
480 async def delete_object(self, object_name: str) -> bool: 481 """ 482 オブジェクトを削除 483 484 Args: 485 object_name: 削除するBlob名 486 487 Returns: 488 成功した場合True 489 """ 490 if not self._enabled: 491 return False 492 493 try: 494 blob_client = self._get_blob_client(object_name) 495 496 await asyncio.get_event_loop().run_in_executor( 497 None, 498 lambda: blob_client.delete_blob(delete_snapshots="include") 499 ) 500 501 logger.info("Successfully deleted blob") 502 return True 503 504 except Exception as e: 505 logger.error(f"Failed to delete blob: {e}") 506 return False
オブジェクトを削除
Args: object_name: 削除するBlob名
Returns: 成功した場合True
508 async def object_exists(self, object_name: str) -> bool: 509 """ 510 オブジェクトの存在確認 511 512 Args: 513 object_name: 確認するBlob名 514 515 Returns: 516 存在する場合True 517 """ 518 if not self._enabled: 519 return False 520 521 try: 522 blob_client = self._get_blob_client(object_name) 523 return await asyncio.get_event_loop().run_in_executor( 524 None, blob_client.exists 525 ) 526 except Exception as e: 527 logger.error(f"Failed to check blob existence: {e}") 528 return False
オブジェクトの存在確認
Args: object_name: 確認するBlob名
Returns: 存在する場合True
530 async def get_object_url( 531 self, 532 object_name: str, 533 expires_in: Optional[int] = None, 534 ) -> str: 535 """ 536 オブジェクトのURLを取得 537 538 Args: 539 object_name: Blob名 540 expires_in: SAS有効期限(秒)- 未実装、将来拡張用 541 542 Returns: 543 オブジェクトURL 544 """ 545 if not self._enabled: 546 return "" 547 548 blob_client = self._get_blob_client(object_name) 549 return self._generate_public_url(blob_client.url)
オブジェクトのURLを取得
Args: object_name: Blob名 expires_in: SAS有効期限(秒)- 未実装、将来拡張用
Returns: オブジェクトURL
569 async def download_objects_by_prefix( 570 self, 571 prefix: str, 572 download_dir: str, 573 ) -> Dict[str, Any]: 574 """ 575 プレフィックスに一致するオブジェクトをすべてダウンロード 576 577 Args: 578 prefix: プレフィックス 579 download_dir: ダウンロード先ディレクトリ 580 581 Returns: 582 ダウンロード結果サマリー 583 584 Security: 585 パストラバーサル攻撃を防ぐため、オブジェクト名をサニタイズし、 586 ダウンロード先がdownload_dir配下であることを検証します。 587 """ 588 list_result = await self.list_objects(prefix) 589 if not list_result.success: 590 return { 591 "success": False, 592 "error": list_result.error, 593 "downloaded": [], 594 "failed": [], 595 } 596 597 download_dir_obj = Path(download_dir).resolve() 598 downloaded = [] 599 failed = [] 600 skipped = [] 601 602 for obj in list_result.objects: 603 # プレフィックス除去 604 relative_path = obj.name 605 if obj.name.startswith(prefix): 606 relative_path = obj.name[len(prefix):] 607 608 # パスサニタイズ(パストラバーサル防止) 609 relative_path = self._sanitize_path(relative_path) 610 611 if not relative_path: 612 skipped.append({ 613 "object_name": obj.name, 614 "reason": "Invalid path after sanitization", 615 }) 616 continue 617 618 download_path = (download_dir_obj / relative_path).resolve() 619 620 # パストラバーサル検証: download_dir配下であることを確認 621 try: 622 download_path.relative_to(download_dir_obj) 623 except ValueError: 624 logger.warning( 625 f"Path traversal attempt detected, skipping object" 626 ) 627 skipped.append({ 628 "object_name": obj.name, 629 "reason": "Path traversal detected", 630 }) 631 continue 632 633 result = await self.download_file(obj.name, str(download_path)) 634 if result.success: 635 downloaded.append({ 636 "object_name": obj.name, 637 "local_path": result.local_path, 638 "file_size": result.file_size, 639 }) 640 else: 641 failed.append({ 642 "object_name": obj.name, 643 "error": result.error, 644 }) 645 646 return { 647 "success": True, 648 "downloaded": downloaded, 649 "failed": failed, 650 "skipped": skipped, 651 "success_count": len(downloaded), 652 "failed_count": len(failed), 653 "skipped_count": len(skipped), 654 }
プレフィックスに一致するオブジェクトをすべてダウンロード
Args: prefix: プレフィックス download_dir: ダウンロード先ディレクトリ
Returns: ダウンロード結果サマリー
Security: パストラバーサル攻撃を防ぐため、オブジェクト名をサニタイズし、 ダウンロード先がdownload_dir配下であることを検証します。
656 async def close(self) -> None: 657 """クライアントリソースをクローズ 658 659 Note: 660 例外が発生してもリソースを確実にクリーンアップします。 661 """ 662 try: 663 if self._blob_service_client: 664 # BlobServiceClientはclose()メソッドを持たないが、 665 # 将来のバージョンで追加される可能性があるため確認 666 if hasattr(self._blob_service_client, 'close'): 667 self._blob_service_client.close() 668 except Exception as e: 669 logger.warning(f"Error closing Azure Blob client: {e}") 670 finally: 671 self._blob_service_client = None 672 self._enabled = False 673 logger.debug("AzureBlobStorageClient closed")
クライアントリソースをクローズ
Note: 例外が発生してもリソースを確実にクリーンアップします。
36class S3StorageClient(StorageClientBase): 37 """ 38 AWS S3ストレージクライアント(スタブ実装) 39 40 Note: 41 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 42 43 Example: 44 config = S3Config( 45 bucket_name="my-bucket", 46 aws_access_key_id="AKIA...", 47 aws_secret_access_key="...", 48 region_name="ap-northeast-1", 49 ) 50 client = S3StorageClient(config) 51 result = await client.upload_file("/path/to/file.txt") 52 """ 53 54 def __init__(self, config: S3Config): 55 """ 56 Initialize AWS S3 Storage Client 57 58 Args: 59 config: S3Config instance 60 61 Raises: 62 StorageConfigError: 設定が不正な場合 63 """ 64 super().__init__(config) 65 66 if not S3_AVAILABLE: 67 raise StorageConfigError( 68 "boto3 package is not installed. " 69 "Install with: pip install boto3" 70 ) 71 72 if not config.bucket_name: 73 raise StorageConfigError("S3 bucket_name is required") 74 75 if not config.aws_access_key_id or not config.aws_secret_access_key: 76 raise StorageConfigError( 77 "S3 requires aws_access_key_id and aws_secret_access_key" 78 ) 79 80 self._s3_config: S3Config = config 81 self._s3_client = None 82 self._s3_resource = None 83 84 try: 85 session = boto3.Session( 86 aws_access_key_id=self._get_api_key(config.aws_access_key_id), 87 aws_secret_access_key=self._get_api_key(config.aws_secret_access_key), 88 region_name=config.region_name, 89 ) 90 91 client_kwargs = {} 92 if config.endpoint_url: 93 client_kwargs['endpoint_url'] = config.endpoint_url 94 95 self._s3_client = session.client('s3', **client_kwargs) 96 self._s3_resource = session.resource('s3', **client_kwargs) 97 self._enabled = config.enabled 98 99 logger.warning( 100 "S3StorageClient is a STUB implementation. " 101 "upload_file, download_file等のメソッドはNOT_IMPLEMENTEDを返します。" 102 ) 103 logger.info( 104 f"S3StorageClient initialized - bucket={config.bucket_name}, " 105 f"region={config.region_name}" 106 ) 107 108 except Exception as e: 109 logger.error(f"Failed to initialize S3 client: {e}") 110 raise StorageConfigError(f"Failed to connect to S3: {e}") from e 111 112 @property 113 def provider(self) -> StorageProvider: 114 return StorageProvider.AWS_S3 115 116 @property 117 def bucket_name(self) -> str: 118 return self._s3_config.bucket_name 119 120 async def ensure_bucket_exists(self) -> bool: 121 """バケットの存在確認・作成""" 122 if not self._enabled: 123 return False 124 125 try: 126 # asyncioブロック防止: run_in_executorでオフロード 127 await asyncio.get_event_loop().run_in_executor( 128 None, 129 lambda: self._s3_client.head_bucket(Bucket=self.bucket_name) 130 ) 131 return True 132 except ClientError as e: 133 error_code = e.response.get('Error', {}).get('Code', '') 134 if error_code == '404': 135 if self._s3_config.auto_create_bucket: 136 try: 137 create_params = {'Bucket': self.bucket_name} 138 if self._s3_config.region_name != 'us-east-1': 139 create_params['CreateBucketConfiguration'] = { 140 'LocationConstraint': self._s3_config.region_name 141 } 142 await asyncio.get_event_loop().run_in_executor( 143 None, 144 lambda: self._s3_client.create_bucket(**create_params) 145 ) 146 logger.info(f"Created S3 bucket: {self.bucket_name}") 147 return True 148 except Exception as create_error: 149 logger.error(f"Failed to create bucket: {create_error}") 150 return False 151 else: 152 logger.error(f"Bucket {self.bucket_name} does not exist") 153 return False 154 logger.error(f"Failed to check bucket: {e}") 155 return False 156 except Exception as e: 157 logger.error(f"Failed to ensure bucket exists: {e}") 158 return False 159 160 async def upload_file( 161 self, 162 file_path: str, 163 object_name: Optional[str] = None, 164 prefix: Optional[str] = None, 165 metadata: Optional[Dict[str, Any]] = None, 166 ) -> UploadResult: 167 """ファイルをS3にアップロード""" 168 # TODO: 本格実装 169 return UploadResult( 170 success=False, 171 error="S3 upload not yet implemented", 172 error_code="NOT_IMPLEMENTED", 173 ) 174 175 async def download_file( 176 self, 177 object_name: str, 178 download_path: str, 179 ) -> DownloadResult: 180 """ファイルをダウンロード""" 181 # TODO: 本格実装 182 return DownloadResult( 183 success=False, 184 error="S3 download not yet implemented", 185 error_code="NOT_IMPLEMENTED", 186 ) 187 188 async def list_objects( 189 self, 190 prefix: str = "", 191 max_results: Optional[int] = None, 192 ) -> ListResult: 193 """オブジェクト一覧を取得""" 194 # TODO: 本格実装 195 return ListResult( 196 success=False, 197 error="S3 list not yet implemented", 198 error_code="NOT_IMPLEMENTED", 199 ) 200 201 async def delete_object(self, object_name: str) -> bool: 202 """オブジェクトを削除""" 203 # TODO: 本格実装 204 return False 205 206 async def object_exists(self, object_name: str) -> bool: 207 """オブジェクトの存在確認""" 208 # TODO: 本格実装 209 return False 210 211 async def get_object_url( 212 self, 213 object_name: str, 214 expires_in: Optional[int] = None, 215 ) -> str: 216 """署名付きURLを取得""" 217 # TODO: 本格実装 218 return "" 219 220 async def close(self) -> None: 221 """クライアントリソースをクローズ""" 222 try: 223 # boto3クライアントのコネクションプールを適切にクローズ 224 if self._s3_client: 225 self._s3_client.close() 226 if self._s3_resource and hasattr(self._s3_resource, 'meta'): 227 self._s3_resource.meta.client.close() 228 except Exception as e: 229 logger.warning(f"Error during S3 client close: {e}") 230 finally: 231 self._s3_client = None 232 self._s3_resource = None 233 self._enabled = False 234 logger.debug("S3StorageClient closed")
AWS S3ストレージクライアント(スタブ実装)
Note: 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。
Example: config = S3Config( bucket_name="my-bucket", aws_access_key_id="AKIA...", aws_secret_access_key="...", region_name="ap-northeast-1", ) client = S3StorageClient(config) result = await client.upload_file("/path/to/file.txt")
54 def __init__(self, config: S3Config): 55 """ 56 Initialize AWS S3 Storage Client 57 58 Args: 59 config: S3Config instance 60 61 Raises: 62 StorageConfigError: 設定が不正な場合 63 """ 64 super().__init__(config) 65 66 if not S3_AVAILABLE: 67 raise StorageConfigError( 68 "boto3 package is not installed. " 69 "Install with: pip install boto3" 70 ) 71 72 if not config.bucket_name: 73 raise StorageConfigError("S3 bucket_name is required") 74 75 if not config.aws_access_key_id or not config.aws_secret_access_key: 76 raise StorageConfigError( 77 "S3 requires aws_access_key_id and aws_secret_access_key" 78 ) 79 80 self._s3_config: S3Config = config 81 self._s3_client = None 82 self._s3_resource = None 83 84 try: 85 session = boto3.Session( 86 aws_access_key_id=self._get_api_key(config.aws_access_key_id), 87 aws_secret_access_key=self._get_api_key(config.aws_secret_access_key), 88 region_name=config.region_name, 89 ) 90 91 client_kwargs = {} 92 if config.endpoint_url: 93 client_kwargs['endpoint_url'] = config.endpoint_url 94 95 self._s3_client = session.client('s3', **client_kwargs) 96 self._s3_resource = session.resource('s3', **client_kwargs) 97 self._enabled = config.enabled 98 99 logger.warning( 100 "S3StorageClient is a STUB implementation. " 101 "upload_file, download_file等のメソッドはNOT_IMPLEMENTEDを返します。" 102 ) 103 logger.info( 104 f"S3StorageClient initialized - bucket={config.bucket_name}, " 105 f"region={config.region_name}" 106 ) 107 108 except Exception as e: 109 logger.error(f"Failed to initialize S3 client: {e}") 110 raise StorageConfigError(f"Failed to connect to S3: {e}") from e
Initialize AWS S3 Storage Client
Args: config: S3Config instance
Raises: StorageConfigError: 設定が不正な場合
120 async def ensure_bucket_exists(self) -> bool: 121 """バケットの存在確認・作成""" 122 if not self._enabled: 123 return False 124 125 try: 126 # asyncioブロック防止: run_in_executorでオフロード 127 await asyncio.get_event_loop().run_in_executor( 128 None, 129 lambda: self._s3_client.head_bucket(Bucket=self.bucket_name) 130 ) 131 return True 132 except ClientError as e: 133 error_code = e.response.get('Error', {}).get('Code', '') 134 if error_code == '404': 135 if self._s3_config.auto_create_bucket: 136 try: 137 create_params = {'Bucket': self.bucket_name} 138 if self._s3_config.region_name != 'us-east-1': 139 create_params['CreateBucketConfiguration'] = { 140 'LocationConstraint': self._s3_config.region_name 141 } 142 await asyncio.get_event_loop().run_in_executor( 143 None, 144 lambda: self._s3_client.create_bucket(**create_params) 145 ) 146 logger.info(f"Created S3 bucket: {self.bucket_name}") 147 return True 148 except Exception as create_error: 149 logger.error(f"Failed to create bucket: {create_error}") 150 return False 151 else: 152 logger.error(f"Bucket {self.bucket_name} does not exist") 153 return False 154 logger.error(f"Failed to check bucket: {e}") 155 return False 156 except Exception as e: 157 logger.error(f"Failed to ensure bucket exists: {e}") 158 return False
バケットの存在確認・作成
160 async def upload_file( 161 self, 162 file_path: str, 163 object_name: Optional[str] = None, 164 prefix: Optional[str] = None, 165 metadata: Optional[Dict[str, Any]] = None, 166 ) -> UploadResult: 167 """ファイルをS3にアップロード""" 168 # TODO: 本格実装 169 return UploadResult( 170 success=False, 171 error="S3 upload not yet implemented", 172 error_code="NOT_IMPLEMENTED", 173 )
ファイルをS3にアップロード
175 async def download_file( 176 self, 177 object_name: str, 178 download_path: str, 179 ) -> DownloadResult: 180 """ファイルをダウンロード""" 181 # TODO: 本格実装 182 return DownloadResult( 183 success=False, 184 error="S3 download not yet implemented", 185 error_code="NOT_IMPLEMENTED", 186 )
ファイルをダウンロード
188 async def list_objects( 189 self, 190 prefix: str = "", 191 max_results: Optional[int] = None, 192 ) -> ListResult: 193 """オブジェクト一覧を取得""" 194 # TODO: 本格実装 195 return ListResult( 196 success=False, 197 error="S3 list not yet implemented", 198 error_code="NOT_IMPLEMENTED", 199 )
オブジェクト一覧を取得
201 async def delete_object(self, object_name: str) -> bool: 202 """オブジェクトを削除""" 203 # TODO: 本格実装 204 return False
オブジェクトを削除
206 async def object_exists(self, object_name: str) -> bool: 207 """オブジェクトの存在確認""" 208 # TODO: 本格実装 209 return False
オブジェクトの存在確認
211 async def get_object_url( 212 self, 213 object_name: str, 214 expires_in: Optional[int] = None, 215 ) -> str: 216 """署名付きURLを取得""" 217 # TODO: 本格実装 218 return ""
署名付きURLを取得
220 async def close(self) -> None: 221 """クライアントリソースをクローズ""" 222 try: 223 # boto3クライアントのコネクションプールを適切にクローズ 224 if self._s3_client: 225 self._s3_client.close() 226 if self._s3_resource and hasattr(self._s3_resource, 'meta'): 227 self._s3_resource.meta.client.close() 228 except Exception as e: 229 logger.warning(f"Error during S3 client close: {e}") 230 finally: 231 self._s3_client = None 232 self._s3_resource = None 233 self._enabled = False 234 logger.debug("S3StorageClient closed")
クライアントリソースをクローズ
36class GCSStorageClient(StorageClientBase): 37 """ 38 Google Cloud Storageクライアント(スタブ実装) 39 40 Note: 41 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 42 43 Example: 44 config = GCSConfig( 45 bucket_name="my-bucket", 46 project_id="my-project", 47 credentials_path="/path/to/service-account.json", 48 ) 49 client = GCSStorageClient(config) 50 result = await client.upload_file("/path/to/file.txt") 51 """ 52 53 def __init__(self, config: GCSConfig): 54 """ 55 Initialize Google Cloud Storage Client 56 57 Args: 58 config: GCSConfig instance 59 60 Raises: 61 StorageConfigError: 設定が不正な場合 62 """ 63 super().__init__(config) 64 65 if not GCS_AVAILABLE: 66 raise StorageConfigError( 67 "google-cloud-storage package is not installed. " 68 "Install with: pip install google-cloud-storage" 69 ) 70 71 if not config.bucket_name: 72 raise StorageConfigError("GCS bucket_name is required") 73 74 if not config.project_id: 75 raise StorageConfigError("GCS project_id is required") 76 77 self._gcs_config: GCSConfig = config 78 self._gcs_client = None 79 self._bucket = None 80 81 try: 82 # 認証情報の設定 83 if config.credentials_path: 84 credentials = service_account.Credentials.from_service_account_file( 85 config.credentials_path 86 ) 87 self._gcs_client = gcs_storage.Client( 88 project=config.project_id, 89 credentials=credentials, 90 ) 91 elif config.credentials_json: 92 import json 93 credentials_info = json.loads(config.credentials_json) 94 credentials = service_account.Credentials.from_service_account_info( 95 credentials_info 96 ) 97 self._gcs_client = gcs_storage.Client( 98 project=config.project_id, 99 credentials=credentials, 100 ) 101 else: 102 # デフォルト認証(環境変数 GOOGLE_APPLICATION_CREDENTIALS) 103 self._gcs_client = gcs_storage.Client(project=config.project_id) 104 105 self._enabled = config.enabled 106 107 logger.warning( 108 "GCSStorageClient is a STUB implementation. " 109 "upload_file, download_file等のメソッドはNOT_IMPLEMENTEDを返します。" 110 ) 111 logger.info( 112 f"GCSStorageClient initialized - bucket={config.bucket_name}, " 113 f"project={config.project_id}" 114 ) 115 116 except Exception as e: 117 logger.error(f"Failed to initialize GCS client: {e}") 118 raise StorageConfigError(f"Failed to connect to GCS: {e}") from e 119 120 @property 121 def provider(self) -> StorageProvider: 122 return StorageProvider.GCS 123 124 @property 125 def bucket_name(self) -> str: 126 return self._gcs_config.bucket_name 127 128 async def ensure_bucket_exists(self) -> bool: 129 """バケットの存在確認・作成""" 130 if not self._enabled: 131 return False 132 133 try: 134 self._bucket = self._gcs_client.bucket(self.bucket_name) 135 136 # asyncioブロック防止: run_in_executorでオフロード 137 exists = await asyncio.get_event_loop().run_in_executor( 138 None, self._bucket.exists 139 ) 140 if exists: 141 return True 142 143 if self._gcs_config.auto_create_bucket: 144 self._bucket = await asyncio.get_event_loop().run_in_executor( 145 None, 146 lambda: self._gcs_client.create_bucket(self.bucket_name) 147 ) 148 logger.info(f"Created GCS bucket: {self.bucket_name}") 149 return True 150 else: 151 logger.error(f"Bucket {self.bucket_name} does not exist") 152 return False 153 154 except Exception as e: 155 logger.error(f"Failed to ensure bucket exists: {e}") 156 return False 157 158 async def upload_file( 159 self, 160 file_path: str, 161 object_name: Optional[str] = None, 162 prefix: Optional[str] = None, 163 metadata: Optional[Dict[str, Any]] = None, 164 ) -> UploadResult: 165 """ファイルをGCSにアップロード""" 166 # TODO: 本格実装 167 return UploadResult( 168 success=False, 169 error="GCS upload not yet implemented", 170 error_code="NOT_IMPLEMENTED", 171 ) 172 173 async def download_file( 174 self, 175 object_name: str, 176 download_path: str, 177 ) -> DownloadResult: 178 """ファイルをダウンロード""" 179 # TODO: 本格実装 180 return DownloadResult( 181 success=False, 182 error="GCS download not yet implemented", 183 error_code="NOT_IMPLEMENTED", 184 ) 185 186 async def list_objects( 187 self, 188 prefix: str = "", 189 max_results: Optional[int] = None, 190 ) -> ListResult: 191 """オブジェクト一覧を取得""" 192 # TODO: 本格実装 193 return ListResult( 194 success=False, 195 error="GCS list not yet implemented", 196 error_code="NOT_IMPLEMENTED", 197 ) 198 199 async def delete_object(self, object_name: str) -> bool: 200 """オブジェクトを削除""" 201 # TODO: 本格実装 202 return False 203 204 async def object_exists(self, object_name: str) -> bool: 205 """オブジェクトの存在確認""" 206 # TODO: 本格実装 207 return False 208 209 async def get_object_url( 210 self, 211 object_name: str, 212 expires_in: Optional[int] = None, 213 ) -> str: 214 """署名付きURLを取得""" 215 # TODO: 本格実装 216 return "" 217 218 async def close(self) -> None: 219 """クライアントリソースをクローズ 220 221 Note: 222 例外が発生してもリソースを確実にクリーンアップします。 223 """ 224 try: 225 if self._gcs_client: 226 self._gcs_client.close() 227 except Exception as e: 228 logger.warning(f"Error closing GCS client: {e}") 229 finally: 230 self._gcs_client = None 231 self._bucket = None 232 self._enabled = False 233 logger.debug("GCSStorageClient closed")
Google Cloud Storageクライアント(スタブ実装)
Note: 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。
Example: config = GCSConfig( bucket_name="my-bucket", project_id="my-project", credentials_path="/path/to/service-account.json", ) client = GCSStorageClient(config) result = await client.upload_file("/path/to/file.txt")
53 def __init__(self, config: GCSConfig): 54 """ 55 Initialize Google Cloud Storage Client 56 57 Args: 58 config: GCSConfig instance 59 60 Raises: 61 StorageConfigError: 設定が不正な場合 62 """ 63 super().__init__(config) 64 65 if not GCS_AVAILABLE: 66 raise StorageConfigError( 67 "google-cloud-storage package is not installed. " 68 "Install with: pip install google-cloud-storage" 69 ) 70 71 if not config.bucket_name: 72 raise StorageConfigError("GCS bucket_name is required") 73 74 if not config.project_id: 75 raise StorageConfigError("GCS project_id is required") 76 77 self._gcs_config: GCSConfig = config 78 self._gcs_client = None 79 self._bucket = None 80 81 try: 82 # 認証情報の設定 83 if config.credentials_path: 84 credentials = service_account.Credentials.from_service_account_file( 85 config.credentials_path 86 ) 87 self._gcs_client = gcs_storage.Client( 88 project=config.project_id, 89 credentials=credentials, 90 ) 91 elif config.credentials_json: 92 import json 93 credentials_info = json.loads(config.credentials_json) 94 credentials = service_account.Credentials.from_service_account_info( 95 credentials_info 96 ) 97 self._gcs_client = gcs_storage.Client( 98 project=config.project_id, 99 credentials=credentials, 100 ) 101 else: 102 # デフォルト認証(環境変数 GOOGLE_APPLICATION_CREDENTIALS) 103 self._gcs_client = gcs_storage.Client(project=config.project_id) 104 105 self._enabled = config.enabled 106 107 logger.warning( 108 "GCSStorageClient is a STUB implementation. " 109 "upload_file, download_file等のメソッドはNOT_IMPLEMENTEDを返します。" 110 ) 111 logger.info( 112 f"GCSStorageClient initialized - bucket={config.bucket_name}, " 113 f"project={config.project_id}" 114 ) 115 116 except Exception as e: 117 logger.error(f"Failed to initialize GCS client: {e}") 118 raise StorageConfigError(f"Failed to connect to GCS: {e}") from e
Initialize Google Cloud Storage Client
Args: config: GCSConfig instance
Raises: StorageConfigError: 設定が不正な場合
128 async def ensure_bucket_exists(self) -> bool: 129 """バケットの存在確認・作成""" 130 if not self._enabled: 131 return False 132 133 try: 134 self._bucket = self._gcs_client.bucket(self.bucket_name) 135 136 # asyncioブロック防止: run_in_executorでオフロード 137 exists = await asyncio.get_event_loop().run_in_executor( 138 None, self._bucket.exists 139 ) 140 if exists: 141 return True 142 143 if self._gcs_config.auto_create_bucket: 144 self._bucket = await asyncio.get_event_loop().run_in_executor( 145 None, 146 lambda: self._gcs_client.create_bucket(self.bucket_name) 147 ) 148 logger.info(f"Created GCS bucket: {self.bucket_name}") 149 return True 150 else: 151 logger.error(f"Bucket {self.bucket_name} does not exist") 152 return False 153 154 except Exception as e: 155 logger.error(f"Failed to ensure bucket exists: {e}") 156 return False
バケットの存在確認・作成
158 async def upload_file( 159 self, 160 file_path: str, 161 object_name: Optional[str] = None, 162 prefix: Optional[str] = None, 163 metadata: Optional[Dict[str, Any]] = None, 164 ) -> UploadResult: 165 """ファイルをGCSにアップロード""" 166 # TODO: 本格実装 167 return UploadResult( 168 success=False, 169 error="GCS upload not yet implemented", 170 error_code="NOT_IMPLEMENTED", 171 )
ファイルをGCSにアップロード
173 async def download_file( 174 self, 175 object_name: str, 176 download_path: str, 177 ) -> DownloadResult: 178 """ファイルをダウンロード""" 179 # TODO: 本格実装 180 return DownloadResult( 181 success=False, 182 error="GCS download not yet implemented", 183 error_code="NOT_IMPLEMENTED", 184 )
ファイルをダウンロード
186 async def list_objects( 187 self, 188 prefix: str = "", 189 max_results: Optional[int] = None, 190 ) -> ListResult: 191 """オブジェクト一覧を取得""" 192 # TODO: 本格実装 193 return ListResult( 194 success=False, 195 error="GCS list not yet implemented", 196 error_code="NOT_IMPLEMENTED", 197 )
オブジェクト一覧を取得
199 async def delete_object(self, object_name: str) -> bool: 200 """オブジェクトを削除""" 201 # TODO: 本格実装 202 return False
オブジェクトを削除
204 async def object_exists(self, object_name: str) -> bool: 205 """オブジェクトの存在確認""" 206 # TODO: 本格実装 207 return False
オブジェクトの存在確認
209 async def get_object_url( 210 self, 211 object_name: str, 212 expires_in: Optional[int] = None, 213 ) -> str: 214 """署名付きURLを取得""" 215 # TODO: 本格実装 216 return ""
署名付きURLを取得
218 async def close(self) -> None: 219 """クライアントリソースをクローズ 220 221 Note: 222 例外が発生してもリソースを確実にクリーンアップします。 223 """ 224 try: 225 if self._gcs_client: 226 self._gcs_client.close() 227 except Exception as e: 228 logger.warning(f"Error closing GCS client: {e}") 229 finally: 230 self._gcs_client = None 231 self._bucket = None 232 self._enabled = False 233 logger.debug("GCSStorageClient closed")
クライアントリソースをクローズ
Note: 例外が発生してもリソースを確実にクリーンアップします。
52@dataclass 53class UploadResult: 54 """アップロード結果""" 55 success: bool 56 object_name: str = "" 57 object_url: str = "" 58 file_size: int = 0 59 content_type: str = "" 60 error: Optional[str] = None 61 error_code: Optional[str] = None 62 metadata: Dict[str, Any] = field(default_factory=dict)
アップロード結果
65@dataclass 66class DownloadResult: 67 """ダウンロード結果""" 68 success: bool 69 object_name: str = "" 70 local_path: str = "" 71 file_size: int = 0 72 error: Optional[str] = None 73 error_code: Optional[str] = None
ダウンロード結果
76@dataclass 77class ObjectInfo: 78 """オブジェクト情報""" 79 name: str 80 size: int 81 last_modified: Optional[str] = None 82 content_type: Optional[str] = None 83 metadata: Dict[str, Any] = field(default_factory=dict)
オブジェクト情報
86@dataclass 87class ListResult: 88 """一覧取得結果""" 89 success: bool 90 objects: List[ObjectInfo] = field(default_factory=list) 91 count: int = 0 92 prefix: str = "" 93 error: Optional[str] = None 94 error_code: Optional[str] = None
一覧取得結果
21class SecurityProvider(str, Enum): 22 """セキュリティプロバイダー種別""" 23 AZURE = "azure" # Azure Content Safety + Language Service 24 AWS = "aws" # AWS Comprehend 25 GCP = "gcp" # Google Cloud DLP
セキュリティプロバイダー種別
28class ContentCategory(str, Enum): 29 """有害コンテンツカテゴリ(プロバイダー共通)""" 30 HATE = "hate" 31 SEXUAL = "sexual" 32 SELF_HARM = "self_harm" 33 VIOLENCE = "violence" 34 PROFANITY = "profanity" 35 INSULT = "insult" 36 THREAT = "threat"
有害コンテンツカテゴリ(プロバイダー共通)
39class PIICategory(str, Enum): 40 """PIIカテゴリ(プロバイダー共通・主要なもの)""" 41 # 個人識別情報 42 PERSON_NAME = "person_name" 43 EMAIL = "email" 44 PHONE_NUMBER = "phone_number" 45 ADDRESS = "address" 46 AGE = "age" 47 DATE_OF_BIRTH = "date_of_birth" 48 49 # 政府発行ID 50 NATIONAL_ID = "national_id" 51 PASSPORT_NUMBER = "passport_number" 52 DRIVERS_LICENSE = "drivers_license" 53 SOCIAL_SECURITY = "social_security" 54 TAX_ID = "tax_id" 55 56 # 金融情報 57 CREDIT_CARD = "credit_card" 58 BANK_ACCOUNT = "bank_account" 59 SWIFT_CODE = "swift_code" 60 IBAN = "iban" 61 62 # 医療情報 63 MEDICAL_RECORD = "medical_record" 64 HEALTH_INSURANCE = "health_insurance" 65 66 # 技術情報 67 IP_ADDRESS = "ip_address" 68 MAC_ADDRESS = "mac_address" 69 URL = "url" 70 71 # 認証情報 72 PASSWORD = "password" 73 API_KEY = "api_key" 74 AWS_ACCESS_KEY = "aws_access_key" 75 CONNECTION_STRING = "connection_string" 76 77 # その他 78 OTHER = "other"
PIIカテゴリ(プロバイダー共通・主要なもの)
Security操作の基底エラー
Security設定エラー
91class SecurityAPIError(SecurityError): 92 """Security API呼び出しエラー""" 93 94 def __init__(self, message: str, status_code: Optional[int] = None, error_code: str = "API_ERROR"): 95 super().__init__(message) 96 self.status_code = status_code 97 self.error_code = error_code
Security API呼び出しエラー
104@dataclass 105class ContentModerationResult: 106 """コンテンツモデレーション結果""" 107 blocked: bool 108 categories: Dict[ContentCategory, int] = field(default_factory=dict) # category -> severity (0-6) 109 threshold: int = 2 110 raw_response: Dict[str, Any] = field(default_factory=dict) 111 error: Optional[str] = None 112 error_code: Optional[str] = None
コンテンツモデレーション結果
115@dataclass 116class PromptShieldResult: 117 """プロンプトシールド(Jailbreak攻撃検出)結果""" 118 attack_detected: bool 119 attack_type: Optional[str] = None # jailbreak, injection, etc. 120 confidence: float = 0.0 121 raw_response: Dict[str, Any] = field(default_factory=dict) 122 error: Optional[str] = None 123 error_code: Optional[str] = None
プロンプトシールド(Jailbreak攻撃検出)結果
126@dataclass 127class PIIEntity: 128 """検出されたPIIエンティティ""" 129 category: PIICategory 130 text: str 131 offset: int 132 length: int 133 confidence: float 134 provider_category: str = "" # プロバイダー固有のカテゴリ名
検出されたPIIエンティティ
137@dataclass 138class PIIDetectionResult: 139 """PII検出結果""" 140 success: bool 141 masked_text: str = "" 142 entities: List[PIIEntity] = field(default_factory=list) 143 categories_detected: List[PIICategory] = field(default_factory=list) 144 raw_response: Dict[str, Any] = field(default_factory=dict) 145 error: Optional[str] = None 146 error_code: Optional[str] = None
PII検出結果
149@dataclass 150class SecurityCheckResult: 151 """統合セキュリティチェック結果""" 152 allowed: bool 153 violations: List[str] = field(default_factory=list) 154 content_moderation: Optional[ContentModerationResult] = None 155 prompt_shield: Optional[PromptShieldResult] = None 156 pii_detection: Optional[PIIDetectionResult] = None
統合セキュリティチェック結果
163@dataclass 164class AzureSecurityConfig: 165 """Azure Security設定 166 167 Note: 168 api_keyはrepr=Falseでログ出力から除外されます。 169 """ 170 # Content Safety 171 content_safety_endpoint: str 172 content_safety_api_key: str = field(repr=False) 173 174 # Language Service (PII) 175 language_endpoint: Optional[str] = None 176 language_api_key: Optional[str] = field(default=None, repr=False) 177 178 # 共通オプション 179 enabled: bool = True 180 prompt_shield_enabled: bool = True 181 moderation_enabled: bool = True 182 moderation_threshold: int = 2 # 0=disabled, 2=recommended, 4=loose, 6=extreme_only 183 pii_enabled: bool = True 184 pii_confidence_threshold: float = 0.7 185 pii_mask_string: str = "***" 186 timeout: float = 30.0 187 188 @property 189 def provider(self) -> SecurityProvider: 190 return SecurityProvider.AZURE 191 192 def __str__(self) -> str: 193 return ( 194 f"AzureSecurityConfig(content_safety_endpoint={self.content_safety_endpoint!r}, " 195 f"language_endpoint={self.language_endpoint!r}, api_keys=***)" 196 )
Azure Security設定
Note: api_keyはrepr=Falseでログ出力から除外されます。
199@dataclass 200class AWSSecurityConfig: 201 """AWS Comprehend設定 202 203 Warning: 204 AWSSecurityClientは現在スタブ実装です。 205 enabled=Falseがデフォルトです。本番環境ではAzureSecurityClientを使用してください。 206 207 Note: 208 aws_access_key_id, aws_secret_access_keyはrepr=Falseでログ出力から除外されます。 209 """ 210 aws_access_key_id: str = field(repr=False) 211 aws_secret_access_key: str = field(repr=False) 212 region_name: str = "us-east-1" 213 214 # 共通オプション 215 # Warning: スタブ実装のためデフォルトはFalse 216 enabled: bool = False 217 moderation_enabled: bool = True 218 pii_enabled: bool = True 219 pii_confidence_threshold: float = 0.7 220 pii_mask_string: str = "***" 221 timeout: float = 30.0 222 223 @property 224 def provider(self) -> SecurityProvider: 225 return SecurityProvider.AWS 226 227 def __str__(self) -> str: 228 return f"AWSSecurityConfig(region_name={self.region_name!r}, credentials=***)"
AWS Comprehend設定
Warning: AWSSecurityClientは現在スタブ実装です。 enabled=Falseがデフォルトです。本番環境ではAzureSecurityClientを使用してください。
Note: aws_access_key_id, aws_secret_access_keyはrepr=Falseでログ出力から除外されます。
231@dataclass 232class GCPSecurityConfig: 233 """Google Cloud DLP設定 234 235 Warning: 236 GCPSecurityClientは現在スタブ実装です。 237 enabled=Falseがデフォルトです。本番環境ではAzureSecurityClientを使用してください。 238 239 Note: 240 credentials_path, credentials_jsonはrepr=Falseでログ出力から除外されます。 241 """ 242 project_id: str 243 credentials_path: Optional[str] = field(default=None, repr=False) 244 credentials_json: Optional[str] = field(default=None, repr=False) 245 246 # 共通オプション 247 # Warning: スタブ実装のためデフォルトはFalse 248 enabled: bool = False 249 pii_enabled: bool = True 250 pii_confidence_threshold: float = 0.7 251 pii_mask_string: str = "***" 252 timeout: float = 30.0 253 254 @property 255 def provider(self) -> SecurityProvider: 256 return SecurityProvider.GCP 257 258 def __str__(self) -> str: 259 return f"GCPSecurityConfig(project_id={self.project_id!r}, credentials=***)"
Google Cloud DLP設定
Warning: GCPSecurityClientは現在スタブ実装です。 enabled=Falseがデフォルトです。本番環境ではAzureSecurityClientを使用してください。
Note: credentials_path, credentials_jsonはrepr=Falseでログ出力から除外されます。
107class AzureSecurityClient(SecurityClientBase): 108 """ 109 Azure Security クライアント 110 111 Azure Content Safety APIとLanguage Service APIを統合し、 112 コンテンツモデレーション、Jailbreak攻撃検出、PII検出・マスキングを提供。 113 114 Example: 115 config = AzureSecurityConfig( 116 content_safety_endpoint="https://xxx.cognitiveservices.azure.com", 117 content_safety_api_key="xxx", 118 ) 119 client = AzureSecurityClient(config) 120 121 # Content Moderation 122 result = await client.check_content_moderation("some text") 123 if result.blocked: 124 print("Content blocked") 125 126 # Prompt Shield 127 result = await client.check_prompt_shield("user prompt") 128 if result.attack_detected: 129 print("Attack detected") 130 131 # PII Detection 132 result = await client.detect_pii("My email is test@example.com") 133 print(result.masked_text) # "My email is ***" 134 135 await client.close() 136 """ 137 138 def __init__(self, config: AzureSecurityConfig): 139 """ 140 Initialize Azure Security Client 141 142 Content SafetyまたはPII(Language Service)のどちらか一方でも設定があれば初期化可能。 143 両方未設定の場合のみエラー。 144 145 Args: 146 config: AzureSecurityConfig instance 147 148 Raises: 149 SecurityConfigError: Content SafetyとPII両方が未設定の場合 150 """ 151 super().__init__(config) 152 153 # Content SafetyまたはPIIのどちらかが設定されていればOK 154 has_content_safety = bool(config.content_safety_endpoint and config.content_safety_api_key) 155 has_pii = bool(config.language_endpoint and config.language_api_key) 156 157 if not has_content_safety and not has_pii: 158 raise SecurityConfigError( 159 "At least one of Content Safety or PII (Language Service) must be configured. " 160 "Provide content_safety_endpoint/api_key or language_endpoint/api_key." 161 ) 162 163 self._azure_config: AzureSecurityConfig = config 164 self._http_client: Optional[httpx.AsyncClient] = None 165 self._enabled = config.enabled 166 self._has_content_safety = has_content_safety 167 self._has_pii = has_pii 168 169 # ログ出力 170 cs_status = f"{config.content_safety_endpoint[:30]}..." if has_content_safety else "disabled" 171 pii_status = "enabled" if has_pii else "disabled" 172 logger.info( 173 f"AzureSecurityClient initialized - " 174 f"content_safety={cs_status}, " 175 f"pii={pii_status}" 176 ) 177 178 @property 179 def provider(self) -> SecurityProvider: 180 return SecurityProvider.AZURE 181 182 def _get_http_client(self) -> httpx.AsyncClient: 183 """HTTP clientを取得(遅延初期化)""" 184 if self._http_client is None or self._http_client.is_closed: 185 self._http_client = httpx.AsyncClient(timeout=self._azure_config.timeout) 186 return self._http_client 187 188 async def check_content_moderation( 189 self, 190 text: str, 191 threshold: Optional[int] = None, 192 ) -> ContentModerationResult: 193 """ 194 有害コンテンツをチェック(Azure Content Safety Moderation API) 195 196 Args: 197 text: チェック対象テキスト(最大10,000文字) 198 threshold: しきい値(0=無効, 2=推奨, 4=緩い, 6=極度のみ) 199 200 Returns: 201 ContentModerationResult: モデレーション結果 202 """ 203 if not self._enabled: 204 return ContentModerationResult(blocked=False, error="Client disabled") 205 206 if not self._has_content_safety: 207 return ContentModerationResult(blocked=False, error="Content Safety not configured") 208 209 if not self._azure_config.moderation_enabled: 210 return ContentModerationResult(blocked=False, threshold=0) 211 212 if not text: 213 return ContentModerationResult(blocked=False) 214 215 effective_threshold = threshold if threshold is not None else self._azure_config.moderation_threshold 216 if effective_threshold == 0: 217 return ContentModerationResult(blocked=False, threshold=0) 218 219 try: 220 endpoint = ( 221 f"{self._azure_config.content_safety_endpoint.rstrip('/')}" 222 f"/contentsafety/text:analyze?api-version=2024-09-01" 223 ) 224 225 request_body = { 226 "text": text[:10000], 227 "categories": ["Hate", "Sexual", "SelfHarm", "Violence"], 228 "outputType": "FourSeverityLevels", 229 } 230 231 client = self._get_http_client() 232 response = await client.post( 233 endpoint, 234 headers={ 235 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.content_safety_api_key), 236 "Content-Type": "application/json", 237 }, 238 json=request_body, 239 ) 240 241 if response.status_code == 200: 242 result = response.json() 243 categories_analysis = result.get("categoriesAnalysis", []) 244 245 # カテゴリ別severity取得 246 categories: Dict[ContentCategory, int] = {} 247 blocked = False 248 249 for cat in categories_analysis: 250 category_name = cat.get("category", "").lower() 251 severity = cat.get("severity", 0) 252 253 normalized_cat = self._normalize_content_category(category_name) 254 categories[normalized_cat] = severity 255 256 if severity >= effective_threshold: 257 blocked = True 258 logger.warning( 259 f"Content moderation blocked: {category_name} " 260 f"severity={severity} >= threshold={effective_threshold}" 261 ) 262 263 return ContentModerationResult( 264 blocked=blocked, 265 categories=categories, 266 threshold=effective_threshold, 267 raw_response=result, 268 ) 269 else: 270 error_msg = f"Status {response.status_code}: {response.text}" 271 logger.error(f"Content Safety Moderation API error: {error_msg}") 272 raise SecurityAPIError(error_msg, response.status_code) 273 274 except SecurityAPIError: 275 raise 276 except httpx.RequestError as e: 277 logger.error(f"Content Safety Moderation request failed: {e}") 278 raise SecurityAPIError(str(e), error_code="REQUEST_ERROR") 279 except Exception as e: 280 logger.error(f"Error in content moderation: {e}") 281 raise SecurityAPIError(str(e), error_code="UNKNOWN_ERROR") 282 283 async def check_prompt_shield( 284 self, 285 user_prompt: str, 286 documents: Optional[List[str]] = None, 287 ) -> PromptShieldResult: 288 """ 289 Jailbreak攻撃を検出(Azure Content Safety Prompt Shield API) 290 291 Args: 292 user_prompt: ユーザープロンプト(最大10,000文字) 293 documents: ドキュメント内容(オプション、外部コンテンツ検証用) 294 295 Returns: 296 PromptShieldResult: Prompt Shield結果 297 """ 298 if not self._enabled: 299 return PromptShieldResult(attack_detected=False, error="Client disabled") 300 301 if not self._has_content_safety: 302 return PromptShieldResult(attack_detected=False, error="Content Safety not configured") 303 304 if not self._azure_config.prompt_shield_enabled: 305 return PromptShieldResult(attack_detected=False) 306 307 if not user_prompt: 308 return PromptShieldResult(attack_detected=False) 309 310 try: 311 endpoint = ( 312 f"{self._azure_config.content_safety_endpoint.rstrip('/')}" 313 f"/contentsafety/text:shieldPrompt?api-version=2024-09-01" 314 ) 315 316 request_body: Dict[str, Any] = { 317 "userPrompt": user_prompt[:10000], 318 } 319 320 if documents: 321 request_body["documents"] = documents 322 323 client = self._get_http_client() 324 response = await client.post( 325 endpoint, 326 headers={ 327 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.content_safety_api_key), 328 "Content-Type": "application/json", 329 }, 330 json=request_body, 331 ) 332 333 if response.status_code == 200: 334 result = response.json() 335 user_analysis = result.get("userPromptAnalysis", {}) 336 attack_detected = user_analysis.get("attackDetected", False) 337 338 attack_type = None 339 if attack_detected: 340 attack_type = "jailbreak" 341 logger.warning("Prompt Shield detected jailbreak attack") 342 343 return PromptShieldResult( 344 attack_detected=attack_detected, 345 attack_type=attack_type, 346 raw_response=result, 347 ) 348 else: 349 error_msg = f"Status {response.status_code}: {response.text}" 350 logger.error(f"Prompt Shield API error: {error_msg}") 351 raise SecurityAPIError(error_msg, response.status_code) 352 353 except SecurityAPIError: 354 raise 355 except httpx.RequestError as e: 356 logger.error(f"Prompt Shield request failed: {e}") 357 raise SecurityAPIError(str(e), error_code="REQUEST_ERROR") 358 except Exception as e: 359 logger.error(f"Error in prompt shield: {e}") 360 raise SecurityAPIError(str(e), error_code="UNKNOWN_ERROR") 361 362 async def detect_pii( 363 self, 364 text: str, 365 mask: bool = True, 366 ) -> PIIDetectionResult: 367 """ 368 PIIを検出・マスク(Azure Language Service PII Detection API) 369 370 Args: 371 text: チェック対象テキスト(最大125,000文字) 372 mask: マスク処理を行うかどうか 373 374 Returns: 375 PIIDetectionResult: PII検出結果 376 """ 377 if not self._enabled: 378 return PIIDetectionResult(success=False, error="Client disabled") 379 380 if not self._has_pii: 381 return PIIDetectionResult(success=False, error="PII (Language Service) not configured") 382 383 if not self._azure_config.pii_enabled: 384 return PIIDetectionResult(success=True, masked_text=text) 385 386 if not text: 387 return PIIDetectionResult(success=True, masked_text="") 388 389 try: 390 endpoint = ( 391 f"{self._azure_config.language_endpoint.rstrip('/')}" 392 f"/language/:analyze-text?api-version=2024-11-01" 393 ) 394 395 request_body = { 396 "kind": "PiiEntityRecognition", 397 "parameters": { 398 "modelVersion": "latest", 399 "domain": "none", 400 "loggingOptOut": True, 401 "piiCategories": AZURE_PII_CATEGORIES, 402 }, 403 "analysisInput": { 404 "documents": [ 405 { 406 "id": "1", 407 "language": "ja", 408 "text": text[:125000], 409 } 410 ] 411 }, 412 } 413 414 client = self._get_http_client() 415 response = await client.post( 416 endpoint, 417 headers={ 418 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.language_api_key), 419 "Content-Type": "application/json", 420 }, 421 json=request_body, 422 ) 423 424 if response.status_code == 200: 425 result = response.json() 426 documents = result.get("results", {}).get("documents", []) 427 428 if not documents: 429 return PIIDetectionResult( 430 success=True, 431 masked_text="" if mask else text, # データ保護優先 432 ) 433 434 doc_result = documents[0] 435 entities_raw = doc_result.get("entities", []) 436 437 # エンティティをPIIEntityに変換 438 entities: List[PIIEntity] = [] 439 categories_detected: List[PIICategory] = [] 440 441 for entity in entities_raw: 442 confidence = entity.get("confidenceScore", 0) 443 if confidence < self._azure_config.pii_confidence_threshold: 444 continue 445 446 provider_category = entity.get("category", "Unknown") 447 normalized_category = self._normalize_pii_category(provider_category) 448 449 entities.append(PIIEntity( 450 category=normalized_category, 451 text=entity.get("text", ""), 452 offset=entity.get("offset", 0), 453 length=entity.get("length", 0), 454 confidence=confidence, 455 provider_category=provider_category, 456 )) 457 458 if normalized_category not in categories_detected: 459 categories_detected.append(normalized_category) 460 461 # マスク処理 462 masked_text = text 463 if mask and entities: 464 # オフセット順で降順ソート(後ろから処理) 465 sorted_entities = sorted(entities, key=lambda x: x.offset, reverse=True) 466 for entity in sorted_entities: 467 masked_text = ( 468 masked_text[:entity.offset] 469 + self._azure_config.pii_mask_string 470 + masked_text[entity.offset + entity.length:] 471 ) 472 473 if entities: 474 logger.info(f"PII detected and masked: {categories_detected}") 475 476 return PIIDetectionResult( 477 success=True, 478 masked_text=masked_text, 479 entities=entities, 480 categories_detected=categories_detected, 481 raw_response=result, 482 ) 483 484 else: 485 error_msg = f"Status {response.status_code}: {response.text}" 486 logger.error(f"Language Service PII API error: {error_msg}") 487 # PII検出エラー時は空文字を返す(データ保護優先) 488 return PIIDetectionResult( 489 success=False, 490 masked_text="", 491 error=error_msg, 492 error_code="API_ERROR", 493 ) 494 495 except httpx.RequestError as e: 496 logger.error(f"Language Service PII request failed: {e}") 497 return PIIDetectionResult( 498 success=False, 499 masked_text="", 500 error=str(e), 501 error_code="REQUEST_ERROR", 502 ) 503 except Exception as e: 504 logger.error(f"Error in PII detection: {e}") 505 return PIIDetectionResult( 506 success=False, 507 masked_text="", 508 error=str(e), 509 error_code="UNKNOWN_ERROR", 510 ) 511 512 async def close(self) -> None: 513 """クライアントリソースをクローズ 514 515 Note: 516 例外が発生してもリソースを確実にクリーンアップします。 517 """ 518 try: 519 if self._http_client and not self._http_client.is_closed: 520 await self._http_client.aclose() 521 except Exception as e: 522 logger.warning(f"Error closing HTTP client: {e}") 523 finally: 524 self._http_client = None 525 self._enabled = False 526 logger.debug("AzureSecurityClient closed")
Azure Security クライアント
Azure Content Safety APIとLanguage Service APIを統合し、 コンテンツモデレーション、Jailbreak攻撃検出、PII検出・マスキングを提供。
Example: config = AzureSecurityConfig( content_safety_endpoint="https://xxx.cognitiveservices.azure.com", content_safety_api_key="xxx", ) client = AzureSecurityClient(config)
# Content Moderation
result = await client.check_content_moderation("some text")
if result.blocked:
print("Content blocked")
# Prompt Shield
result = await client.check_prompt_shield("user prompt")
if result.attack_detected:
print("Attack detected")
# PII Detection
result = await client.detect_pii("My email is test@example.com")
print(result.masked_text) # "My email is ***"
await client.close()
138 def __init__(self, config: AzureSecurityConfig): 139 """ 140 Initialize Azure Security Client 141 142 Content SafetyまたはPII(Language Service)のどちらか一方でも設定があれば初期化可能。 143 両方未設定の場合のみエラー。 144 145 Args: 146 config: AzureSecurityConfig instance 147 148 Raises: 149 SecurityConfigError: Content SafetyとPII両方が未設定の場合 150 """ 151 super().__init__(config) 152 153 # Content SafetyまたはPIIのどちらかが設定されていればOK 154 has_content_safety = bool(config.content_safety_endpoint and config.content_safety_api_key) 155 has_pii = bool(config.language_endpoint and config.language_api_key) 156 157 if not has_content_safety and not has_pii: 158 raise SecurityConfigError( 159 "At least one of Content Safety or PII (Language Service) must be configured. " 160 "Provide content_safety_endpoint/api_key or language_endpoint/api_key." 161 ) 162 163 self._azure_config: AzureSecurityConfig = config 164 self._http_client: Optional[httpx.AsyncClient] = None 165 self._enabled = config.enabled 166 self._has_content_safety = has_content_safety 167 self._has_pii = has_pii 168 169 # ログ出力 170 cs_status = f"{config.content_safety_endpoint[:30]}..." if has_content_safety else "disabled" 171 pii_status = "enabled" if has_pii else "disabled" 172 logger.info( 173 f"AzureSecurityClient initialized - " 174 f"content_safety={cs_status}, " 175 f"pii={pii_status}" 176 )
Initialize Azure Security Client
Content SafetyまたはPII(Language Service)のどちらか一方でも設定があれば初期化可能。 両方未設定の場合のみエラー。
Args: config: AzureSecurityConfig instance
Raises: SecurityConfigError: Content SafetyとPII両方が未設定の場合
188 async def check_content_moderation( 189 self, 190 text: str, 191 threshold: Optional[int] = None, 192 ) -> ContentModerationResult: 193 """ 194 有害コンテンツをチェック(Azure Content Safety Moderation API) 195 196 Args: 197 text: チェック対象テキスト(最大10,000文字) 198 threshold: しきい値(0=無効, 2=推奨, 4=緩い, 6=極度のみ) 199 200 Returns: 201 ContentModerationResult: モデレーション結果 202 """ 203 if not self._enabled: 204 return ContentModerationResult(blocked=False, error="Client disabled") 205 206 if not self._has_content_safety: 207 return ContentModerationResult(blocked=False, error="Content Safety not configured") 208 209 if not self._azure_config.moderation_enabled: 210 return ContentModerationResult(blocked=False, threshold=0) 211 212 if not text: 213 return ContentModerationResult(blocked=False) 214 215 effective_threshold = threshold if threshold is not None else self._azure_config.moderation_threshold 216 if effective_threshold == 0: 217 return ContentModerationResult(blocked=False, threshold=0) 218 219 try: 220 endpoint = ( 221 f"{self._azure_config.content_safety_endpoint.rstrip('/')}" 222 f"/contentsafety/text:analyze?api-version=2024-09-01" 223 ) 224 225 request_body = { 226 "text": text[:10000], 227 "categories": ["Hate", "Sexual", "SelfHarm", "Violence"], 228 "outputType": "FourSeverityLevels", 229 } 230 231 client = self._get_http_client() 232 response = await client.post( 233 endpoint, 234 headers={ 235 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.content_safety_api_key), 236 "Content-Type": "application/json", 237 }, 238 json=request_body, 239 ) 240 241 if response.status_code == 200: 242 result = response.json() 243 categories_analysis = result.get("categoriesAnalysis", []) 244 245 # カテゴリ別severity取得 246 categories: Dict[ContentCategory, int] = {} 247 blocked = False 248 249 for cat in categories_analysis: 250 category_name = cat.get("category", "").lower() 251 severity = cat.get("severity", 0) 252 253 normalized_cat = self._normalize_content_category(category_name) 254 categories[normalized_cat] = severity 255 256 if severity >= effective_threshold: 257 blocked = True 258 logger.warning( 259 f"Content moderation blocked: {category_name} " 260 f"severity={severity} >= threshold={effective_threshold}" 261 ) 262 263 return ContentModerationResult( 264 blocked=blocked, 265 categories=categories, 266 threshold=effective_threshold, 267 raw_response=result, 268 ) 269 else: 270 error_msg = f"Status {response.status_code}: {response.text}" 271 logger.error(f"Content Safety Moderation API error: {error_msg}") 272 raise SecurityAPIError(error_msg, response.status_code) 273 274 except SecurityAPIError: 275 raise 276 except httpx.RequestError as e: 277 logger.error(f"Content Safety Moderation request failed: {e}") 278 raise SecurityAPIError(str(e), error_code="REQUEST_ERROR") 279 except Exception as e: 280 logger.error(f"Error in content moderation: {e}") 281 raise SecurityAPIError(str(e), error_code="UNKNOWN_ERROR")
有害コンテンツをチェック(Azure Content Safety Moderation API)
Args: text: チェック対象テキスト(最大10,000文字) threshold: しきい値(0=無効, 2=推奨, 4=緩い, 6=極度のみ)
Returns: ContentModerationResult: モデレーション結果
283 async def check_prompt_shield( 284 self, 285 user_prompt: str, 286 documents: Optional[List[str]] = None, 287 ) -> PromptShieldResult: 288 """ 289 Jailbreak攻撃を検出(Azure Content Safety Prompt Shield API) 290 291 Args: 292 user_prompt: ユーザープロンプト(最大10,000文字) 293 documents: ドキュメント内容(オプション、外部コンテンツ検証用) 294 295 Returns: 296 PromptShieldResult: Prompt Shield結果 297 """ 298 if not self._enabled: 299 return PromptShieldResult(attack_detected=False, error="Client disabled") 300 301 if not self._has_content_safety: 302 return PromptShieldResult(attack_detected=False, error="Content Safety not configured") 303 304 if not self._azure_config.prompt_shield_enabled: 305 return PromptShieldResult(attack_detected=False) 306 307 if not user_prompt: 308 return PromptShieldResult(attack_detected=False) 309 310 try: 311 endpoint = ( 312 f"{self._azure_config.content_safety_endpoint.rstrip('/')}" 313 f"/contentsafety/text:shieldPrompt?api-version=2024-09-01" 314 ) 315 316 request_body: Dict[str, Any] = { 317 "userPrompt": user_prompt[:10000], 318 } 319 320 if documents: 321 request_body["documents"] = documents 322 323 client = self._get_http_client() 324 response = await client.post( 325 endpoint, 326 headers={ 327 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.content_safety_api_key), 328 "Content-Type": "application/json", 329 }, 330 json=request_body, 331 ) 332 333 if response.status_code == 200: 334 result = response.json() 335 user_analysis = result.get("userPromptAnalysis", {}) 336 attack_detected = user_analysis.get("attackDetected", False) 337 338 attack_type = None 339 if attack_detected: 340 attack_type = "jailbreak" 341 logger.warning("Prompt Shield detected jailbreak attack") 342 343 return PromptShieldResult( 344 attack_detected=attack_detected, 345 attack_type=attack_type, 346 raw_response=result, 347 ) 348 else: 349 error_msg = f"Status {response.status_code}: {response.text}" 350 logger.error(f"Prompt Shield API error: {error_msg}") 351 raise SecurityAPIError(error_msg, response.status_code) 352 353 except SecurityAPIError: 354 raise 355 except httpx.RequestError as e: 356 logger.error(f"Prompt Shield request failed: {e}") 357 raise SecurityAPIError(str(e), error_code="REQUEST_ERROR") 358 except Exception as e: 359 logger.error(f"Error in prompt shield: {e}") 360 raise SecurityAPIError(str(e), error_code="UNKNOWN_ERROR")
Jailbreak攻撃を検出(Azure Content Safety Prompt Shield API)
Args: user_prompt: ユーザープロンプト(最大10,000文字) documents: ドキュメント内容(オプション、外部コンテンツ検証用)
Returns: PromptShieldResult: Prompt Shield結果
362 async def detect_pii( 363 self, 364 text: str, 365 mask: bool = True, 366 ) -> PIIDetectionResult: 367 """ 368 PIIを検出・マスク(Azure Language Service PII Detection API) 369 370 Args: 371 text: チェック対象テキスト(最大125,000文字) 372 mask: マスク処理を行うかどうか 373 374 Returns: 375 PIIDetectionResult: PII検出結果 376 """ 377 if not self._enabled: 378 return PIIDetectionResult(success=False, error="Client disabled") 379 380 if not self._has_pii: 381 return PIIDetectionResult(success=False, error="PII (Language Service) not configured") 382 383 if not self._azure_config.pii_enabled: 384 return PIIDetectionResult(success=True, masked_text=text) 385 386 if not text: 387 return PIIDetectionResult(success=True, masked_text="") 388 389 try: 390 endpoint = ( 391 f"{self._azure_config.language_endpoint.rstrip('/')}" 392 f"/language/:analyze-text?api-version=2024-11-01" 393 ) 394 395 request_body = { 396 "kind": "PiiEntityRecognition", 397 "parameters": { 398 "modelVersion": "latest", 399 "domain": "none", 400 "loggingOptOut": True, 401 "piiCategories": AZURE_PII_CATEGORIES, 402 }, 403 "analysisInput": { 404 "documents": [ 405 { 406 "id": "1", 407 "language": "ja", 408 "text": text[:125000], 409 } 410 ] 411 }, 412 } 413 414 client = self._get_http_client() 415 response = await client.post( 416 endpoint, 417 headers={ 418 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.language_api_key), 419 "Content-Type": "application/json", 420 }, 421 json=request_body, 422 ) 423 424 if response.status_code == 200: 425 result = response.json() 426 documents = result.get("results", {}).get("documents", []) 427 428 if not documents: 429 return PIIDetectionResult( 430 success=True, 431 masked_text="" if mask else text, # データ保護優先 432 ) 433 434 doc_result = documents[0] 435 entities_raw = doc_result.get("entities", []) 436 437 # エンティティをPIIEntityに変換 438 entities: List[PIIEntity] = [] 439 categories_detected: List[PIICategory] = [] 440 441 for entity in entities_raw: 442 confidence = entity.get("confidenceScore", 0) 443 if confidence < self._azure_config.pii_confidence_threshold: 444 continue 445 446 provider_category = entity.get("category", "Unknown") 447 normalized_category = self._normalize_pii_category(provider_category) 448 449 entities.append(PIIEntity( 450 category=normalized_category, 451 text=entity.get("text", ""), 452 offset=entity.get("offset", 0), 453 length=entity.get("length", 0), 454 confidence=confidence, 455 provider_category=provider_category, 456 )) 457 458 if normalized_category not in categories_detected: 459 categories_detected.append(normalized_category) 460 461 # マスク処理 462 masked_text = text 463 if mask and entities: 464 # オフセット順で降順ソート(後ろから処理) 465 sorted_entities = sorted(entities, key=lambda x: x.offset, reverse=True) 466 for entity in sorted_entities: 467 masked_text = ( 468 masked_text[:entity.offset] 469 + self._azure_config.pii_mask_string 470 + masked_text[entity.offset + entity.length:] 471 ) 472 473 if entities: 474 logger.info(f"PII detected and masked: {categories_detected}") 475 476 return PIIDetectionResult( 477 success=True, 478 masked_text=masked_text, 479 entities=entities, 480 categories_detected=categories_detected, 481 raw_response=result, 482 ) 483 484 else: 485 error_msg = f"Status {response.status_code}: {response.text}" 486 logger.error(f"Language Service PII API error: {error_msg}") 487 # PII検出エラー時は空文字を返す(データ保護優先) 488 return PIIDetectionResult( 489 success=False, 490 masked_text="", 491 error=error_msg, 492 error_code="API_ERROR", 493 ) 494 495 except httpx.RequestError as e: 496 logger.error(f"Language Service PII request failed: {e}") 497 return PIIDetectionResult( 498 success=False, 499 masked_text="", 500 error=str(e), 501 error_code="REQUEST_ERROR", 502 ) 503 except Exception as e: 504 logger.error(f"Error in PII detection: {e}") 505 return PIIDetectionResult( 506 success=False, 507 masked_text="", 508 error=str(e), 509 error_code="UNKNOWN_ERROR", 510 )
PIIを検出・マスク(Azure Language Service PII Detection API)
Args: text: チェック対象テキスト(最大125,000文字) mask: マスク処理を行うかどうか
Returns: PIIDetectionResult: PII検出結果
512 async def close(self) -> None: 513 """クライアントリソースをクローズ 514 515 Note: 516 例外が発生してもリソースを確実にクリーンアップします。 517 """ 518 try: 519 if self._http_client and not self._http_client.is_closed: 520 await self._http_client.aclose() 521 except Exception as e: 522 logger.warning(f"Error closing HTTP client: {e}") 523 finally: 524 self._http_client = None 525 self._enabled = False 526 logger.debug("AzureSecurityClient closed")
クライアントリソースをクローズ
Note: 例外が発生してもリソースを確実にクリーンアップします。
58class AWSSecurityClient(SecurityClientBase): 59 """ 60 AWS Comprehend セキュリティクライアント(スタブ実装) 61 62 Note: 63 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 64 detect_pii, check_content_moderation等のメソッドはNOT_IMPLEMENTEDを返します。 65 66 Example: 67 config = AWSSecurityConfig( 68 aws_access_key_id="AKIA...", 69 aws_secret_access_key="...", 70 region_name="us-east-1", 71 ) 72 client = AWSSecurityClient(config) 73 result = await client.detect_pii("My email is test@example.com") 74 """ 75 76 def __init__(self, config: AWSSecurityConfig): 77 """ 78 Initialize AWS Comprehend Security Client 79 80 Args: 81 config: AWSSecurityConfig instance 82 83 Raises: 84 SecurityConfigError: 設定が不正な場合 85 """ 86 super().__init__(config) 87 88 if not BOTO3_AVAILABLE: 89 raise SecurityConfigError( 90 "boto3 package is not installed. " 91 "Install with: pip install boto3" 92 ) 93 94 if not config.aws_access_key_id or not config.aws_secret_access_key: 95 raise SecurityConfigError( 96 "AWS requires aws_access_key_id and aws_secret_access_key" 97 ) 98 99 self._aws_config: AWSSecurityConfig = config 100 self._comprehend_client = None 101 102 try: 103 session = boto3.Session( 104 aws_access_key_id=self._get_api_key(config.aws_access_key_id), 105 aws_secret_access_key=self._get_api_key(config.aws_secret_access_key), 106 region_name=config.region_name, 107 ) 108 109 self._comprehend_client = session.client('comprehend') 110 self._enabled = config.enabled 111 112 logger.warning( 113 "AWSSecurityClient is a STUB implementation. " 114 "detect_pii, check_content_moderation等のメソッドはNOT_IMPLEMENTEDを返します。" 115 ) 116 logger.info( 117 f"AWSSecurityClient initialized - region={config.region_name}" 118 ) 119 120 except Exception as e: 121 logger.error(f"Failed to initialize AWS Comprehend client: {e}") 122 raise SecurityConfigError(f"Failed to connect to AWS Comprehend: {e}") from e 123 124 @property 125 def provider(self) -> SecurityProvider: 126 return SecurityProvider.AWS 127 128 async def check_content_moderation( 129 self, 130 text: str, 131 threshold: Optional[int] = None, 132 ) -> ContentModerationResult: 133 """ 134 有害コンテンツをチェック(AWS Comprehend DetectToxicContent) 135 136 Note: 137 スタブ実装 - NOT_IMPLEMENTEDを返します 138 139 TODO: 140 - DetectToxicContent APIの実装 141 - カテゴリマッピング: HATE_SPEECH, SEXUAL, VIOLENCE_OR_THREAT等 142 """ 143 # TODO: 本格実装 144 # aws comprehend detect-toxic-content --text-segments "Text=<text>" --language-code ja 145 return ContentModerationResult( 146 blocked=False, 147 error="AWS Comprehend content moderation not yet implemented", 148 error_code="NOT_IMPLEMENTED", 149 ) 150 151 async def check_prompt_shield( 152 self, 153 user_prompt: str, 154 documents: Optional[List[str]] = None, 155 ) -> PromptShieldResult: 156 """ 157 Jailbreak攻撃を検出 158 159 Note: 160 AWS ComprehendにはPrompt Shield相当の機能がないため、 161 常にattack_detected=Falseを返します。 162 """ 163 # AWS Comprehendには直接的なPrompt Shield機能がない 164 return PromptShieldResult( 165 attack_detected=False, 166 error="AWS Comprehend does not support Prompt Shield functionality", 167 error_code="NOT_SUPPORTED", 168 ) 169 170 async def detect_pii( 171 self, 172 text: str, 173 mask: bool = True, 174 ) -> PIIDetectionResult: 175 """ 176 PIIを検出・マスク(AWS Comprehend DetectPiiEntities) 177 178 Note: 179 スタブ実装 - NOT_IMPLEMENTEDを返します 180 181 TODO: 182 - DetectPiiEntities APIの実装 183 - サポートカテゴリ: BANK_ACCOUNT_NUMBER, BANK_ROUTING, 184 CREDIT_DEBIT_CVV, CREDIT_DEBIT_EXPIRY, CREDIT_DEBIT_NUMBER, 185 EMAIL, IP_ADDRESS, MAC_ADDRESS, NAME, PASSWORD, PHONE, 186 PIN, SSN, URL, USERNAME等 187 """ 188 # TODO: 本格実装 189 # import asyncio 190 # response = await asyncio.get_event_loop().run_in_executor( 191 # None, 192 # lambda: self._comprehend_client.detect_pii_entities( 193 # Text=text[:5000], # API制限 194 # LanguageCode='ja' 195 # ) 196 # ) 197 return PIIDetectionResult( 198 success=False, 199 masked_text="", 200 error="AWS Comprehend PII detection not yet implemented", 201 error_code="NOT_IMPLEMENTED", 202 ) 203 204 async def close(self) -> None: 205 """クライアントリソースをクローズ 206 207 Note: 208 boto3クライアントは明示的なclose不要ですが、 209 参照をクリアしてリソースを確実に解放します。 210 """ 211 try: 212 # boto3クライアントは明示的なclose不要だが、参照をクリア 213 pass 214 except Exception as e: 215 logger.warning(f"Error during cleanup: {e}") 216 finally: 217 self._comprehend_client = None 218 self._enabled = False 219 logger.debug("AWSSecurityClient closed")
AWS Comprehend セキュリティクライアント(スタブ実装)
Note: 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 detect_pii, check_content_moderation等のメソッドはNOT_IMPLEMENTEDを返します。
Example: config = AWSSecurityConfig( aws_access_key_id="AKIA...", aws_secret_access_key="...", region_name="us-east-1", ) client = AWSSecurityClient(config) result = await client.detect_pii("My email is test@example.com")
76 def __init__(self, config: AWSSecurityConfig): 77 """ 78 Initialize AWS Comprehend Security Client 79 80 Args: 81 config: AWSSecurityConfig instance 82 83 Raises: 84 SecurityConfigError: 設定が不正な場合 85 """ 86 super().__init__(config) 87 88 if not BOTO3_AVAILABLE: 89 raise SecurityConfigError( 90 "boto3 package is not installed. " 91 "Install with: pip install boto3" 92 ) 93 94 if not config.aws_access_key_id or not config.aws_secret_access_key: 95 raise SecurityConfigError( 96 "AWS requires aws_access_key_id and aws_secret_access_key" 97 ) 98 99 self._aws_config: AWSSecurityConfig = config 100 self._comprehend_client = None 101 102 try: 103 session = boto3.Session( 104 aws_access_key_id=self._get_api_key(config.aws_access_key_id), 105 aws_secret_access_key=self._get_api_key(config.aws_secret_access_key), 106 region_name=config.region_name, 107 ) 108 109 self._comprehend_client = session.client('comprehend') 110 self._enabled = config.enabled 111 112 logger.warning( 113 "AWSSecurityClient is a STUB implementation. " 114 "detect_pii, check_content_moderation等のメソッドはNOT_IMPLEMENTEDを返します。" 115 ) 116 logger.info( 117 f"AWSSecurityClient initialized - region={config.region_name}" 118 ) 119 120 except Exception as e: 121 logger.error(f"Failed to initialize AWS Comprehend client: {e}") 122 raise SecurityConfigError(f"Failed to connect to AWS Comprehend: {e}") from e
Initialize AWS Comprehend Security Client
Args: config: AWSSecurityConfig instance
Raises: SecurityConfigError: 設定が不正な場合
128 async def check_content_moderation( 129 self, 130 text: str, 131 threshold: Optional[int] = None, 132 ) -> ContentModerationResult: 133 """ 134 有害コンテンツをチェック(AWS Comprehend DetectToxicContent) 135 136 Note: 137 スタブ実装 - NOT_IMPLEMENTEDを返します 138 139 TODO: 140 - DetectToxicContent APIの実装 141 - カテゴリマッピング: HATE_SPEECH, SEXUAL, VIOLENCE_OR_THREAT等 142 """ 143 # TODO: 本格実装 144 # aws comprehend detect-toxic-content --text-segments "Text=<text>" --language-code ja 145 return ContentModerationResult( 146 blocked=False, 147 error="AWS Comprehend content moderation not yet implemented", 148 error_code="NOT_IMPLEMENTED", 149 )
有害コンテンツをチェック(AWS Comprehend DetectToxicContent)
Note: スタブ実装 - NOT_IMPLEMENTEDを返します
TODO: - DetectToxicContent APIの実装 - カテゴリマッピング: HATE_SPEECH, SEXUAL, VIOLENCE_OR_THREAT等
151 async def check_prompt_shield( 152 self, 153 user_prompt: str, 154 documents: Optional[List[str]] = None, 155 ) -> PromptShieldResult: 156 """ 157 Jailbreak攻撃を検出 158 159 Note: 160 AWS ComprehendにはPrompt Shield相当の機能がないため、 161 常にattack_detected=Falseを返します。 162 """ 163 # AWS Comprehendには直接的なPrompt Shield機能がない 164 return PromptShieldResult( 165 attack_detected=False, 166 error="AWS Comprehend does not support Prompt Shield functionality", 167 error_code="NOT_SUPPORTED", 168 )
Jailbreak攻撃を検出
Note: AWS ComprehendにはPrompt Shield相当の機能がないため、 常にattack_detected=Falseを返します。
170 async def detect_pii( 171 self, 172 text: str, 173 mask: bool = True, 174 ) -> PIIDetectionResult: 175 """ 176 PIIを検出・マスク(AWS Comprehend DetectPiiEntities) 177 178 Note: 179 スタブ実装 - NOT_IMPLEMENTEDを返します 180 181 TODO: 182 - DetectPiiEntities APIの実装 183 - サポートカテゴリ: BANK_ACCOUNT_NUMBER, BANK_ROUTING, 184 CREDIT_DEBIT_CVV, CREDIT_DEBIT_EXPIRY, CREDIT_DEBIT_NUMBER, 185 EMAIL, IP_ADDRESS, MAC_ADDRESS, NAME, PASSWORD, PHONE, 186 PIN, SSN, URL, USERNAME等 187 """ 188 # TODO: 本格実装 189 # import asyncio 190 # response = await asyncio.get_event_loop().run_in_executor( 191 # None, 192 # lambda: self._comprehend_client.detect_pii_entities( 193 # Text=text[:5000], # API制限 194 # LanguageCode='ja' 195 # ) 196 # ) 197 return PIIDetectionResult( 198 success=False, 199 masked_text="", 200 error="AWS Comprehend PII detection not yet implemented", 201 error_code="NOT_IMPLEMENTED", 202 )
PIIを検出・マスク(AWS Comprehend DetectPiiEntities)
Note: スタブ実装 - NOT_IMPLEMENTEDを返します
TODO: - DetectPiiEntities APIの実装 - サポートカテゴリ: BANK_ACCOUNT_NUMBER, BANK_ROUTING, CREDIT_DEBIT_CVV, CREDIT_DEBIT_EXPIRY, CREDIT_DEBIT_NUMBER, EMAIL, IP_ADDRESS, MAC_ADDRESS, NAME, PASSWORD, PHONE, PIN, SSN, URL, USERNAME等
204 async def close(self) -> None: 205 """クライアントリソースをクローズ 206 207 Note: 208 boto3クライアントは明示的なclose不要ですが、 209 参照をクリアしてリソースを確実に解放します。 210 """ 211 try: 212 # boto3クライアントは明示的なclose不要だが、参照をクリア 213 pass 214 except Exception as e: 215 logger.warning(f"Error during cleanup: {e}") 216 finally: 217 self._comprehend_client = None 218 self._enabled = False 219 logger.debug("AWSSecurityClient closed")
クライアントリソースをクローズ
Note: boto3クライアントは明示的なclose不要ですが、 参照をクリアしてリソースを確実に解放します。
58class GCPSecurityClient(SecurityClientBase): 59 """ 60 Google Cloud DLP セキュリティクライアント(スタブ実装) 61 62 Note: 63 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 64 detect_pii等のメソッドはNOT_IMPLEMENTEDを返します。 65 66 Example: 67 config = GCPSecurityConfig( 68 project_id="my-project", 69 credentials_path="/path/to/service-account.json", 70 ) 71 client = GCPSecurityClient(config) 72 result = await client.detect_pii("My email is test@example.com") 73 """ 74 75 def __init__(self, config: GCPSecurityConfig): 76 """ 77 Initialize Google Cloud DLP Security Client 78 79 Args: 80 config: GCPSecurityConfig instance 81 82 Raises: 83 SecurityConfigError: 設定が不正な場合 84 """ 85 super().__init__(config) 86 87 if not GCP_DLP_AVAILABLE: 88 raise SecurityConfigError( 89 "google-cloud-dlp package is not installed. " 90 "Install with: pip install google-cloud-dlp" 91 ) 92 93 if not config.project_id: 94 raise SecurityConfigError("GCP project_id is required") 95 96 self._gcp_config: GCPSecurityConfig = config 97 self._dlp_client = None 98 99 try: 100 # 認証情報の設定 101 if config.credentials_path: 102 credentials = service_account.Credentials.from_service_account_file( 103 config.credentials_path 104 ) 105 self._dlp_client = dlp_v2.DlpServiceClient(credentials=credentials) 106 elif config.credentials_json: 107 import json 108 credentials_info = json.loads(config.credentials_json) 109 credentials = service_account.Credentials.from_service_account_info( 110 credentials_info 111 ) 112 self._dlp_client = dlp_v2.DlpServiceClient(credentials=credentials) 113 else: 114 # デフォルト認証(環境変数 GOOGLE_APPLICATION_CREDENTIALS) 115 self._dlp_client = dlp_v2.DlpServiceClient() 116 117 self._enabled = config.enabled 118 119 logger.warning( 120 "GCPSecurityClient is a STUB implementation. " 121 "detect_pii等のメソッドはNOT_IMPLEMENTEDを返します。" 122 ) 123 logger.info( 124 f"GCPSecurityClient initialized - project={config.project_id}" 125 ) 126 127 except Exception as e: 128 logger.error(f"Failed to initialize GCP DLP client: {e}") 129 raise SecurityConfigError(f"Failed to connect to GCP DLP: {e}") from e 130 131 @property 132 def provider(self) -> SecurityProvider: 133 return SecurityProvider.GCP 134 135 async def check_content_moderation( 136 self, 137 text: str, 138 threshold: Optional[int] = None, 139 ) -> ContentModerationResult: 140 """ 141 有害コンテンツをチェック 142 143 Note: 144 GCP DLPにはContent Moderation機能がないため、 145 常にblocked=Falseを返します。 146 ※ コンテンツモデレーションにはPerspective APIを検討 147 """ 148 # GCP DLPには直接的なContent Moderation機能がない 149 # Perspective API (Google Jigsaw) または Vertex AI Safety Filtersが必要 150 return ContentModerationResult( 151 blocked=False, 152 error="GCP DLP does not support content moderation. Consider Perspective API.", 153 error_code="NOT_SUPPORTED", 154 ) 155 156 async def check_prompt_shield( 157 self, 158 user_prompt: str, 159 documents: Optional[List[str]] = None, 160 ) -> PromptShieldResult: 161 """ 162 Jailbreak攻撃を検出 163 164 Note: 165 GCP DLPにはPrompt Shield相当の機能がないため、 166 常にattack_detected=Falseを返します。 167 """ 168 # GCP DLPには直接的なPrompt Shield機能がない 169 return PromptShieldResult( 170 attack_detected=False, 171 error="GCP DLP does not support Prompt Shield functionality", 172 error_code="NOT_SUPPORTED", 173 ) 174 175 async def detect_pii( 176 self, 177 text: str, 178 mask: bool = True, 179 ) -> PIIDetectionResult: 180 """ 181 PIIを検出・マスク(Google Cloud DLP inspect_content / deidentify_content) 182 183 Note: 184 スタブ実装 - NOT_IMPLEMENTEDを返します 185 186 TODO: 187 - inspect_content APIの実装 188 - deidentify_content APIの実装 189 - サポートInfoType: CREDIT_CARD_NUMBER, EMAIL_ADDRESS, PHONE_NUMBER, 190 STREET_ADDRESS, PERSON_NAME, DATE_OF_BIRTH, PASSPORT, 191 JAPAN_BANK_ACCOUNT, JAPAN_DRIVERS_LICENSE_NUMBER, 192 JAPAN_INDIVIDUAL_NUMBER, JAPAN_PASSPORT等 193 """ 194 # TODO: 本格実装 195 # parent = f"projects/{self._gcp_config.project_id}" 196 # 197 # inspect_config = { 198 # "info_types": [ 199 # {"name": "CREDIT_CARD_NUMBER"}, 200 # {"name": "EMAIL_ADDRESS"}, 201 # {"name": "PHONE_NUMBER"}, 202 # {"name": "PERSON_NAME"}, 203 # ... 204 # ], 205 # "min_likelihood": "POSSIBLE", # VERY_UNLIKELY, UNLIKELY, POSSIBLE, LIKELY, VERY_LIKELY 206 # } 207 # 208 # item = {"value": text} 209 # 210 # # asyncioブロック防止 211 # response = await asyncio.get_event_loop().run_in_executor( 212 # None, 213 # lambda: self._dlp_client.inspect_content( 214 # request={"parent": parent, "inspect_config": inspect_config, "item": item} 215 # ) 216 # ) 217 # 218 # # deidentify_content for masking 219 # deidentify_config = { 220 # "info_type_transformations": { 221 # "transformations": [ 222 # { 223 # "primitive_transformation": { 224 # "replace_config": {"new_value": {"string_value": self._gcp_config.pii_mask_string}} 225 # } 226 # } 227 # ] 228 # } 229 # } 230 return PIIDetectionResult( 231 success=False, 232 masked_text="", 233 error="GCP DLP PII detection not yet implemented", 234 error_code="NOT_IMPLEMENTED", 235 ) 236 237 async def close(self) -> None: 238 """クライアントリソースをクローズ 239 240 Note: 241 gRPCクライアントのクローズ時に例外が発生しても 242 リソースを確実にクリーンアップします。 243 """ 244 try: 245 # gRPCクライアントのクローズ 246 if self._dlp_client: 247 # DlpServiceClientにはclose()メソッドがある場合がある 248 if hasattr(self._dlp_client, 'transport') and hasattr(self._dlp_client.transport, 'close'): 249 self._dlp_client.transport.close() 250 except Exception as e: 251 logger.warning(f"Error closing DLP client: {e}") 252 finally: 253 self._dlp_client = None 254 self._enabled = False 255 logger.debug("GCPSecurityClient closed")
Google Cloud DLP セキュリティクライアント(スタブ実装)
Note: 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 detect_pii等のメソッドはNOT_IMPLEMENTEDを返します。
Example: config = GCPSecurityConfig( project_id="my-project", credentials_path="/path/to/service-account.json", ) client = GCPSecurityClient(config) result = await client.detect_pii("My email is test@example.com")
75 def __init__(self, config: GCPSecurityConfig): 76 """ 77 Initialize Google Cloud DLP Security Client 78 79 Args: 80 config: GCPSecurityConfig instance 81 82 Raises: 83 SecurityConfigError: 設定が不正な場合 84 """ 85 super().__init__(config) 86 87 if not GCP_DLP_AVAILABLE: 88 raise SecurityConfigError( 89 "google-cloud-dlp package is not installed. " 90 "Install with: pip install google-cloud-dlp" 91 ) 92 93 if not config.project_id: 94 raise SecurityConfigError("GCP project_id is required") 95 96 self._gcp_config: GCPSecurityConfig = config 97 self._dlp_client = None 98 99 try: 100 # 認証情報の設定 101 if config.credentials_path: 102 credentials = service_account.Credentials.from_service_account_file( 103 config.credentials_path 104 ) 105 self._dlp_client = dlp_v2.DlpServiceClient(credentials=credentials) 106 elif config.credentials_json: 107 import json 108 credentials_info = json.loads(config.credentials_json) 109 credentials = service_account.Credentials.from_service_account_info( 110 credentials_info 111 ) 112 self._dlp_client = dlp_v2.DlpServiceClient(credentials=credentials) 113 else: 114 # デフォルト認証(環境変数 GOOGLE_APPLICATION_CREDENTIALS) 115 self._dlp_client = dlp_v2.DlpServiceClient() 116 117 self._enabled = config.enabled 118 119 logger.warning( 120 "GCPSecurityClient is a STUB implementation. " 121 "detect_pii等のメソッドはNOT_IMPLEMENTEDを返します。" 122 ) 123 logger.info( 124 f"GCPSecurityClient initialized - project={config.project_id}" 125 ) 126 127 except Exception as e: 128 logger.error(f"Failed to initialize GCP DLP client: {e}") 129 raise SecurityConfigError(f"Failed to connect to GCP DLP: {e}") from e
Initialize Google Cloud DLP Security Client
Args: config: GCPSecurityConfig instance
Raises: SecurityConfigError: 設定が不正な場合
135 async def check_content_moderation( 136 self, 137 text: str, 138 threshold: Optional[int] = None, 139 ) -> ContentModerationResult: 140 """ 141 有害コンテンツをチェック 142 143 Note: 144 GCP DLPにはContent Moderation機能がないため、 145 常にblocked=Falseを返します。 146 ※ コンテンツモデレーションにはPerspective APIを検討 147 """ 148 # GCP DLPには直接的なContent Moderation機能がない 149 # Perspective API (Google Jigsaw) または Vertex AI Safety Filtersが必要 150 return ContentModerationResult( 151 blocked=False, 152 error="GCP DLP does not support content moderation. Consider Perspective API.", 153 error_code="NOT_SUPPORTED", 154 )
有害コンテンツをチェック
Note: GCP DLPにはContent Moderation機能がないため、 常にblocked=Falseを返します。 ※ コンテンツモデレーションにはPerspective APIを検討
156 async def check_prompt_shield( 157 self, 158 user_prompt: str, 159 documents: Optional[List[str]] = None, 160 ) -> PromptShieldResult: 161 """ 162 Jailbreak攻撃を検出 163 164 Note: 165 GCP DLPにはPrompt Shield相当の機能がないため、 166 常にattack_detected=Falseを返します。 167 """ 168 # GCP DLPには直接的なPrompt Shield機能がない 169 return PromptShieldResult( 170 attack_detected=False, 171 error="GCP DLP does not support Prompt Shield functionality", 172 error_code="NOT_SUPPORTED", 173 )
Jailbreak攻撃を検出
Note: GCP DLPにはPrompt Shield相当の機能がないため、 常にattack_detected=Falseを返します。
175 async def detect_pii( 176 self, 177 text: str, 178 mask: bool = True, 179 ) -> PIIDetectionResult: 180 """ 181 PIIを検出・マスク(Google Cloud DLP inspect_content / deidentify_content) 182 183 Note: 184 スタブ実装 - NOT_IMPLEMENTEDを返します 185 186 TODO: 187 - inspect_content APIの実装 188 - deidentify_content APIの実装 189 - サポートInfoType: CREDIT_CARD_NUMBER, EMAIL_ADDRESS, PHONE_NUMBER, 190 STREET_ADDRESS, PERSON_NAME, DATE_OF_BIRTH, PASSPORT, 191 JAPAN_BANK_ACCOUNT, JAPAN_DRIVERS_LICENSE_NUMBER, 192 JAPAN_INDIVIDUAL_NUMBER, JAPAN_PASSPORT等 193 """ 194 # TODO: 本格実装 195 # parent = f"projects/{self._gcp_config.project_id}" 196 # 197 # inspect_config = { 198 # "info_types": [ 199 # {"name": "CREDIT_CARD_NUMBER"}, 200 # {"name": "EMAIL_ADDRESS"}, 201 # {"name": "PHONE_NUMBER"}, 202 # {"name": "PERSON_NAME"}, 203 # ... 204 # ], 205 # "min_likelihood": "POSSIBLE", # VERY_UNLIKELY, UNLIKELY, POSSIBLE, LIKELY, VERY_LIKELY 206 # } 207 # 208 # item = {"value": text} 209 # 210 # # asyncioブロック防止 211 # response = await asyncio.get_event_loop().run_in_executor( 212 # None, 213 # lambda: self._dlp_client.inspect_content( 214 # request={"parent": parent, "inspect_config": inspect_config, "item": item} 215 # ) 216 # ) 217 # 218 # # deidentify_content for masking 219 # deidentify_config = { 220 # "info_type_transformations": { 221 # "transformations": [ 222 # { 223 # "primitive_transformation": { 224 # "replace_config": {"new_value": {"string_value": self._gcp_config.pii_mask_string}} 225 # } 226 # } 227 # ] 228 # } 229 # } 230 return PIIDetectionResult( 231 success=False, 232 masked_text="", 233 error="GCP DLP PII detection not yet implemented", 234 error_code="NOT_IMPLEMENTED", 235 )
PIIを検出・マスク(Google Cloud DLP inspect_content / deidentify_content)
Note: スタブ実装 - NOT_IMPLEMENTEDを返します
TODO: - inspect_content APIの実装 - deidentify_content APIの実装 - サポートInfoType: CREDIT_CARD_NUMBER, EMAIL_ADDRESS, PHONE_NUMBER, STREET_ADDRESS, PERSON_NAME, DATE_OF_BIRTH, PASSPORT, JAPAN_BANK_ACCOUNT, JAPAN_DRIVERS_LICENSE_NUMBER, JAPAN_INDIVIDUAL_NUMBER, JAPAN_PASSPORT等
237 async def close(self) -> None: 238 """クライアントリソースをクローズ 239 240 Note: 241 gRPCクライアントのクローズ時に例外が発生しても 242 リソースを確実にクリーンアップします。 243 """ 244 try: 245 # gRPCクライアントのクローズ 246 if self._dlp_client: 247 # DlpServiceClientにはclose()メソッドがある場合がある 248 if hasattr(self._dlp_client, 'transport') and hasattr(self._dlp_client.transport, 'close'): 249 self._dlp_client.transport.close() 250 except Exception as e: 251 logger.warning(f"Error closing DLP client: {e}") 252 finally: 253 self._dlp_client = None 254 self._enabled = False 255 logger.debug("GCPSecurityClient closed")
クライアントリソースをクローズ
Note: gRPCクライアントのクローズ時に例外が発生しても リソースを確実にクリーンアップします。
62class LibreChatAuthClient: 63 """ 64 LibreChat Auth API クライアント 65 66 認証API(ユーザー管理、MCP OAuthトークン)へのアクセスを提供。 67 X-Internal-API-Key認証を使用します。 68 69 Attributes: 70 config: LibreChatAuthConfig設定 71 72 Example: 73 ```python 74 config = LibreChatAuthConfig.create( 75 base_url="https://auth.sbtestfeature.agenticstar.tm.softbank.jp", 76 api_key="xxx", 77 ) 78 async with LibreChatAuthClient(config) as client: 79 users = await client.get_users() 80 ``` 81 """ 82 83 def __init__(self, config: LibreChatAuthConfig): 84 """ 85 LibreChat Auth Clientを初期化 86 87 Args: 88 config: LibreChatAuthConfig インスタンス 89 90 Raises: 91 AuthConfigError: 設定が不正な場合 92 """ 93 if not config.base_url: 94 raise AuthConfigError("base_url is required") 95 96 if not config.api_key: 97 raise AuthConfigError("api_key is required") 98 99 self._config = config 100 self._http_client: Optional[httpx.AsyncClient] = None 101 102 logger.info( 103 f"LibreChatAuthClient initialized - base_url={config.base_url[:30]}..." 104 ) 105 106 async def __aenter__(self) -> "LibreChatAuthClient": 107 """Context Manager: 開始時の処理""" 108 return self 109 110 async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: 111 """Context Manager: 終了時にリソースをクローズ""" 112 await self.close() 113 114 def _get_http_client(self) -> httpx.AsyncClient: 115 """HTTP clientを取得(遅延初期化)""" 116 if self._http_client is None or self._http_client.is_closed: 117 self._http_client = httpx.AsyncClient(timeout=self._config.timeout) 118 return self._http_client 119 120 def _get_headers(self) -> Dict[str, str]: 121 """共通ヘッダーを取得""" 122 return { 123 "X-Internal-API-Key": self._config.api_key, 124 "Content-Type": "application/json", 125 } 126 127 async def _request( 128 self, 129 method: str, 130 endpoint: str, 131 params: Optional[Dict[str, Any]] = None, 132 json_body: Optional[Dict[str, Any]] = None, 133 ) -> Dict[str, Any]: 134 """HTTPリクエスト実行 135 136 Args: 137 method: HTTPメソッド (GET, POST, etc.) 138 endpoint: APIエンドポイント(baseURLからの相対パス) 139 params: クエリパラメータ 140 json_body: JSONリクエストボディ 141 142 Returns: 143 JSONレスポンス 144 145 Raises: 146 AuthUnauthorizedError: 認証失敗 (401) 147 AuthNotFoundError: リソース未発見 (404) 148 AuthRateLimitError: レート制限 (429) 149 AuthAPIError: その他のAPIエラー 150 """ 151 url = f"{self._config.base_url.rstrip('/')}{endpoint}" 152 client = self._get_http_client() 153 headers = self._get_headers() 154 155 try: 156 if method.upper() == "GET": 157 response = await client.get(url, headers=headers, params=params) 158 elif method.upper() == "POST": 159 response = await client.post(url, headers=headers, json=json_body) 160 else: 161 raise AuthAPIError(f"Unsupported HTTP method: {method}") 162 163 # ステータスコード別処理 164 if response.status_code == 200: 165 return response.json() 166 elif response.status_code == 401: 167 raise AuthUnauthorizedError() 168 elif response.status_code == 404: 169 raise AuthNotFoundError() 170 elif response.status_code == 429: 171 raise AuthRateLimitError() 172 else: 173 error_msg = f"Status {response.status_code}: {response.text}" 174 raise AuthAPIError(error_msg, response.status_code) 175 176 except httpx.RequestError as e: 177 raise AuthAPIError(f"Request failed: {e}") 178 179 # ============================================================ 180 # ユーザー一覧取得 181 # ============================================================ 182 183 async def get_users( 184 self, 185 page: int = 1, 186 limit: int = 20, 187 search: Optional[str] = None, 188 is_approved: Optional[str] = None, 189 job: Optional[str] = None, 190 organization: Optional[str] = None, 191 sort_by: Optional[str] = None, 192 order: Optional[str] = None, 193 include_last_login: bool = False, 194 ) -> GetUsersResult: 195 """ 196 ユーザー一覧を取得 197 198 Args: 199 page: ページ番号(デフォルト: 1) 200 limit: 1ページあたりの件数(デフォルト: 20) 201 search: 検索キーワード(email, username, displayNameで検索) 202 is_approved: 承認状態フィルター ('true', 'false', 'all') 203 job: 職種でフィルター 204 organization: 組織でフィルター 205 sort_by: ソート項目 ('created_timestamp', 'email', 'username') 206 order: ソート順 ('asc', 'desc') 207 include_last_login: 最終ログイン時刻を取得するか 208 209 Returns: 210 GetUsersResult: ユーザー一覧結果 211 212 Example: 213 ```python 214 result = await client.get_users( 215 page=1, 216 limit=20, 217 is_approved="true", 218 sort_by="created_timestamp", 219 order="desc", 220 ) 221 if result.success: 222 for user in result.users: 223 print(f"User: {user.email}") 224 ``` 225 """ 226 try: 227 params: Dict[str, Any] = { 228 "page": page, 229 "limit": limit, 230 } 231 232 if search: 233 params["search"] = search 234 if is_approved: 235 params["is_approved"] = is_approved 236 if job: 237 params["job"] = job 238 if organization: 239 params["organization"] = organization 240 if sort_by: 241 params["sort_by"] = sort_by 242 if order: 243 params["order"] = order 244 if include_last_login: 245 params["include_last_login"] = "true" 246 247 response = await self._request( 248 "GET", "/api/v1/admin/users", params=params 249 ) 250 251 # レスポンス解析 252 data = response.get("data", {}) 253 items_raw = data.get("items", []) 254 pagination_raw = data.get("pagination", {}) 255 256 users = [ApiUser.model_validate(item) for item in items_raw] 257 pagination = ( 258 UserPagination.model_validate(pagination_raw) 259 if pagination_raw 260 else None 261 ) 262 263 return GetUsersResult( 264 success=True, 265 users=users, 266 pagination=pagination, 267 ) 268 269 except (AuthUnauthorizedError, AuthNotFoundError, AuthRateLimitError, AuthAPIError) as e: 270 logger.error(f"Failed to get users: {e}") 271 return GetUsersResult( 272 success=False, 273 error=str(e), 274 error_code=getattr(e, "error_code", "API_ERROR"), 275 ) 276 except Exception as e: 277 logger.error(f"Unexpected error getting users: {e}") 278 return GetUsersResult( 279 success=False, 280 error=str(e), 281 error_code="UNKNOWN_ERROR", 282 ) 283 284 # ============================================================ 285 # ユーザー詳細取得 286 # ============================================================ 287 288 async def get_user(self, user_id: str) -> GetUserResult: 289 """ 290 ユーザー詳細を取得 291 292 Args: 293 user_id: ユーザーID 294 295 Returns: 296 GetUserResult: ユーザー詳細結果 297 298 Example: 299 ```python 300 result = await client.get_user("c45005f6-b715-47bb-a3d7-44ab208cecc4") 301 if result.success: 302 print(f"User: {result.user.email}") 303 if result.devices: 304 print(f"Devices: {len(result.devices)}") 305 ``` 306 """ 307 if not user_id: 308 return GetUserResult( 309 success=False, 310 error="user_id is required", 311 error_code="INVALID_PARAMETER", 312 ) 313 314 try: 315 response = await self._request( 316 "GET", f"/api/v1/admin/users/{user_id}" 317 ) 318 319 user_raw = response.get("user", {}) 320 devices_raw = response.get("devices", []) 321 login_history_raw = response.get("loginHistory", []) 322 323 user = ApiUser.model_validate(user_raw) 324 # 空配列はそのまま空配列として返す(Noneに変換しない) 325 devices = [DeviceInfo.model_validate(d) for d in devices_raw] 326 login_history = [LoginHistoryEntry.model_validate(h) for h in login_history_raw] 327 328 # ソフトデリート情報 329 is_deleted = user_raw.get("isDeleted", False) 330 deleted_at = user_raw.get("deletedAt") 331 deleted_by = user_raw.get("deletedBy") 332 deletion_reason = user_raw.get("deletionReason") 333 334 return GetUserResult( 335 success=True, 336 user=user, 337 devices=devices, 338 login_history=login_history, 339 is_deleted=is_deleted, 340 deleted_at=deleted_at, 341 deleted_by=deleted_by, 342 deletion_reason=deletion_reason, 343 ) 344 345 except AuthNotFoundError: 346 return GetUserResult( 347 success=False, 348 error=f"User not found: {user_id}", 349 error_code="NOT_FOUND", 350 ) 351 except (AuthUnauthorizedError, AuthRateLimitError, AuthAPIError) as e: 352 logger.error(f"Failed to get user {user_id}: {e}") 353 return GetUserResult( 354 success=False, 355 error=str(e), 356 error_code=getattr(e, "error_code", "API_ERROR"), 357 ) 358 except Exception as e: 359 logger.error(f"Unexpected error getting user {user_id}: {e}") 360 return GetUserResult( 361 success=False, 362 error=str(e), 363 error_code="UNKNOWN_ERROR", 364 ) 365 366 # ============================================================ 367 # MCP OAuth トークン取得 368 # ============================================================ 369 370 async def get_mcp_tokens( 371 self, 372 user_id: str, 373 providers: Optional[List[OAuthProviderName]] = None, 374 ) -> GetMCPTokensResult: 375 """ 376 MCP用OAuthトークンを取得 377 378 Args: 379 user_id: ユーザーID 380 providers: 取得するプロバイダー(省略時は全プロバイダー) 381 例: ["github", "slack", "google", "office365"] 382 383 Returns: 384 GetMCPTokensResult: MCPトークン取得結果 385 386 Example: 387 ```python 388 result = await client.get_mcp_tokens( 389 "user-id-123", 390 providers=["github", "slack"], 391 ) 392 if result.success: 393 for provider, token in result.tokens.items(): 394 print(f"{provider}: expires at {token.expires_at}") 395 if result.errors: 396 for provider, err in result.errors.items(): 397 print(f"{provider} error: {err.message}") 398 ``` 399 """ 400 if not user_id: 401 return GetMCPTokensResult( 402 success=False, 403 error="user_id is required", 404 error_code="INVALID_PARAMETER", 405 ) 406 407 try: 408 body: Dict[str, Any] = {"user_id": user_id} 409 if providers: 410 body["providers"] = providers 411 412 response = await self._request( 413 "POST", "/api/v1/oauth/mcp/token", json_body=body 414 ) 415 416 tokens_raw = response.get("tokens", {}) 417 errors_raw = response.get("errors", {}) 418 419 # トークン変換(expires_at: string -> datetime) 420 tokens: Dict[str, MCPTokenInfo] = {} 421 token_errors: Dict[str, MCPTokenError] = {} 422 423 for provider, token_data in tokens_raw.items(): 424 expires_at_str = token_data.get("expires_at") 425 426 # expires_atが欠損または不正な場合はエラーとして扱う 427 if not expires_at_str: 428 token_errors[provider] = MCPTokenError( 429 code="INVALID_TOKEN_DATA", 430 message="expires_at is missing", 431 ) 432 continue 433 434 try: 435 expires_at = datetime.fromisoformat( 436 expires_at_str.replace("Z", "+00:00") 437 ) 438 except (ValueError, TypeError) as e: 439 token_errors[provider] = MCPTokenError( 440 code="INVALID_EXPIRES_AT", 441 message=f"Invalid expires_at format: {expires_at_str}", 442 ) 443 continue 444 445 tokens[provider] = MCPTokenInfo( 446 access_token=token_data.get("access_token", ""), 447 token_type=token_data.get("token_type", "Bearer"), 448 expires_at=expires_at, 449 scopes=token_data.get("scopes", []), 450 ) 451 452 # APIからのエラーを追加 453 for provider, err in errors_raw.items(): 454 token_errors[provider] = MCPTokenError( 455 code=err.get("code", "UNKNOWN"), 456 message=err.get("message", "Unknown error"), 457 ) 458 459 # エラーがあればerrorsに設定 460 final_errors = token_errors if token_errors else None 461 462 return GetMCPTokensResult( 463 success=True, 464 tokens=tokens, 465 errors=final_errors, 466 ) 467 468 except (AuthUnauthorizedError, AuthRateLimitError, AuthAPIError) as e: 469 logger.error(f"Failed to get MCP tokens for {user_id}: {e}") 470 return GetMCPTokensResult( 471 success=False, 472 error=str(e), 473 error_code=getattr(e, "error_code", "API_ERROR"), 474 ) 475 except Exception as e: 476 logger.error(f"Unexpected error getting MCP tokens for {user_id}: {e}") 477 return GetMCPTokensResult( 478 success=False, 479 error=str(e), 480 error_code="UNKNOWN_ERROR", 481 ) 482 483 # ============================================================ 484 # リソースクローズ 485 # ============================================================ 486 487 async def close(self) -> None: 488 """クライアントリソースをクローズ""" 489 try: 490 if self._http_client and not self._http_client.is_closed: 491 await self._http_client.aclose() 492 except Exception as e: 493 logger.warning(f"Error closing HTTP client: {e}") 494 finally: 495 self._http_client = None 496 logger.debug("LibreChatAuthClient closed")
LibreChat Auth API クライアント
認証API(ユーザー管理、MCP OAuthトークン)へのアクセスを提供。 X-Internal-API-Key認証を使用します。
Attributes: config: LibreChatAuthConfig設定
Example:
config = LibreChatAuthConfig.create(
base_url="https://auth.sbtestfeature.agenticstar.tm.softbank.jp",
api_key="xxx",
)
async with LibreChatAuthClient(config) as client:
users = await client.get_users()
83 def __init__(self, config: LibreChatAuthConfig): 84 """ 85 LibreChat Auth Clientを初期化 86 87 Args: 88 config: LibreChatAuthConfig インスタンス 89 90 Raises: 91 AuthConfigError: 設定が不正な場合 92 """ 93 if not config.base_url: 94 raise AuthConfigError("base_url is required") 95 96 if not config.api_key: 97 raise AuthConfigError("api_key is required") 98 99 self._config = config 100 self._http_client: Optional[httpx.AsyncClient] = None 101 102 logger.info( 103 f"LibreChatAuthClient initialized - base_url={config.base_url[:30]}..." 104 )
LibreChat Auth Clientを初期化
Args: config: LibreChatAuthConfig インスタンス
Raises: AuthConfigError: 設定が不正な場合
183 async def get_users( 184 self, 185 page: int = 1, 186 limit: int = 20, 187 search: Optional[str] = None, 188 is_approved: Optional[str] = None, 189 job: Optional[str] = None, 190 organization: Optional[str] = None, 191 sort_by: Optional[str] = None, 192 order: Optional[str] = None, 193 include_last_login: bool = False, 194 ) -> GetUsersResult: 195 """ 196 ユーザー一覧を取得 197 198 Args: 199 page: ページ番号(デフォルト: 1) 200 limit: 1ページあたりの件数(デフォルト: 20) 201 search: 検索キーワード(email, username, displayNameで検索) 202 is_approved: 承認状態フィルター ('true', 'false', 'all') 203 job: 職種でフィルター 204 organization: 組織でフィルター 205 sort_by: ソート項目 ('created_timestamp', 'email', 'username') 206 order: ソート順 ('asc', 'desc') 207 include_last_login: 最終ログイン時刻を取得するか 208 209 Returns: 210 GetUsersResult: ユーザー一覧結果 211 212 Example: 213 ```python 214 result = await client.get_users( 215 page=1, 216 limit=20, 217 is_approved="true", 218 sort_by="created_timestamp", 219 order="desc", 220 ) 221 if result.success: 222 for user in result.users: 223 print(f"User: {user.email}") 224 ``` 225 """ 226 try: 227 params: Dict[str, Any] = { 228 "page": page, 229 "limit": limit, 230 } 231 232 if search: 233 params["search"] = search 234 if is_approved: 235 params["is_approved"] = is_approved 236 if job: 237 params["job"] = job 238 if organization: 239 params["organization"] = organization 240 if sort_by: 241 params["sort_by"] = sort_by 242 if order: 243 params["order"] = order 244 if include_last_login: 245 params["include_last_login"] = "true" 246 247 response = await self._request( 248 "GET", "/api/v1/admin/users", params=params 249 ) 250 251 # レスポンス解析 252 data = response.get("data", {}) 253 items_raw = data.get("items", []) 254 pagination_raw = data.get("pagination", {}) 255 256 users = [ApiUser.model_validate(item) for item in items_raw] 257 pagination = ( 258 UserPagination.model_validate(pagination_raw) 259 if pagination_raw 260 else None 261 ) 262 263 return GetUsersResult( 264 success=True, 265 users=users, 266 pagination=pagination, 267 ) 268 269 except (AuthUnauthorizedError, AuthNotFoundError, AuthRateLimitError, AuthAPIError) as e: 270 logger.error(f"Failed to get users: {e}") 271 return GetUsersResult( 272 success=False, 273 error=str(e), 274 error_code=getattr(e, "error_code", "API_ERROR"), 275 ) 276 except Exception as e: 277 logger.error(f"Unexpected error getting users: {e}") 278 return GetUsersResult( 279 success=False, 280 error=str(e), 281 error_code="UNKNOWN_ERROR", 282 )
ユーザー一覧を取得
Args: page: ページ番号(デフォルト: 1) limit: 1ページあたりの件数(デフォルト: 20) search: 検索キーワード(email, username, displayNameで検索) is_approved: 承認状態フィルター ('true', 'false', 'all') job: 職種でフィルター organization: 組織でフィルター sort_by: ソート項目 ('created_timestamp', 'email', 'username') order: ソート順 ('asc', 'desc') include_last_login: 最終ログイン時刻を取得するか
Returns: GetUsersResult: ユーザー一覧結果
Example:
result = await client.get_users(
page=1,
limit=20,
is_approved="true",
sort_by="created_timestamp",
order="desc",
)
if result.success:
for user in result.users:
print(f"User: {user.email}")
288 async def get_user(self, user_id: str) -> GetUserResult: 289 """ 290 ユーザー詳細を取得 291 292 Args: 293 user_id: ユーザーID 294 295 Returns: 296 GetUserResult: ユーザー詳細結果 297 298 Example: 299 ```python 300 result = await client.get_user("c45005f6-b715-47bb-a3d7-44ab208cecc4") 301 if result.success: 302 print(f"User: {result.user.email}") 303 if result.devices: 304 print(f"Devices: {len(result.devices)}") 305 ``` 306 """ 307 if not user_id: 308 return GetUserResult( 309 success=False, 310 error="user_id is required", 311 error_code="INVALID_PARAMETER", 312 ) 313 314 try: 315 response = await self._request( 316 "GET", f"/api/v1/admin/users/{user_id}" 317 ) 318 319 user_raw = response.get("user", {}) 320 devices_raw = response.get("devices", []) 321 login_history_raw = response.get("loginHistory", []) 322 323 user = ApiUser.model_validate(user_raw) 324 # 空配列はそのまま空配列として返す(Noneに変換しない) 325 devices = [DeviceInfo.model_validate(d) for d in devices_raw] 326 login_history = [LoginHistoryEntry.model_validate(h) for h in login_history_raw] 327 328 # ソフトデリート情報 329 is_deleted = user_raw.get("isDeleted", False) 330 deleted_at = user_raw.get("deletedAt") 331 deleted_by = user_raw.get("deletedBy") 332 deletion_reason = user_raw.get("deletionReason") 333 334 return GetUserResult( 335 success=True, 336 user=user, 337 devices=devices, 338 login_history=login_history, 339 is_deleted=is_deleted, 340 deleted_at=deleted_at, 341 deleted_by=deleted_by, 342 deletion_reason=deletion_reason, 343 ) 344 345 except AuthNotFoundError: 346 return GetUserResult( 347 success=False, 348 error=f"User not found: {user_id}", 349 error_code="NOT_FOUND", 350 ) 351 except (AuthUnauthorizedError, AuthRateLimitError, AuthAPIError) as e: 352 logger.error(f"Failed to get user {user_id}: {e}") 353 return GetUserResult( 354 success=False, 355 error=str(e), 356 error_code=getattr(e, "error_code", "API_ERROR"), 357 ) 358 except Exception as e: 359 logger.error(f"Unexpected error getting user {user_id}: {e}") 360 return GetUserResult( 361 success=False, 362 error=str(e), 363 error_code="UNKNOWN_ERROR", 364 )
ユーザー詳細を取得
Args: user_id: ユーザーID
Returns: GetUserResult: ユーザー詳細結果
Example:
result = await client.get_user("c45005f6-b715-47bb-a3d7-44ab208cecc4")
if result.success:
print(f"User: {result.user.email}")
if result.devices:
print(f"Devices: {len(result.devices)}")
370 async def get_mcp_tokens( 371 self, 372 user_id: str, 373 providers: Optional[List[OAuthProviderName]] = None, 374 ) -> GetMCPTokensResult: 375 """ 376 MCP用OAuthトークンを取得 377 378 Args: 379 user_id: ユーザーID 380 providers: 取得するプロバイダー(省略時は全プロバイダー) 381 例: ["github", "slack", "google", "office365"] 382 383 Returns: 384 GetMCPTokensResult: MCPトークン取得結果 385 386 Example: 387 ```python 388 result = await client.get_mcp_tokens( 389 "user-id-123", 390 providers=["github", "slack"], 391 ) 392 if result.success: 393 for provider, token in result.tokens.items(): 394 print(f"{provider}: expires at {token.expires_at}") 395 if result.errors: 396 for provider, err in result.errors.items(): 397 print(f"{provider} error: {err.message}") 398 ``` 399 """ 400 if not user_id: 401 return GetMCPTokensResult( 402 success=False, 403 error="user_id is required", 404 error_code="INVALID_PARAMETER", 405 ) 406 407 try: 408 body: Dict[str, Any] = {"user_id": user_id} 409 if providers: 410 body["providers"] = providers 411 412 response = await self._request( 413 "POST", "/api/v1/oauth/mcp/token", json_body=body 414 ) 415 416 tokens_raw = response.get("tokens", {}) 417 errors_raw = response.get("errors", {}) 418 419 # トークン変換(expires_at: string -> datetime) 420 tokens: Dict[str, MCPTokenInfo] = {} 421 token_errors: Dict[str, MCPTokenError] = {} 422 423 for provider, token_data in tokens_raw.items(): 424 expires_at_str = token_data.get("expires_at") 425 426 # expires_atが欠損または不正な場合はエラーとして扱う 427 if not expires_at_str: 428 token_errors[provider] = MCPTokenError( 429 code="INVALID_TOKEN_DATA", 430 message="expires_at is missing", 431 ) 432 continue 433 434 try: 435 expires_at = datetime.fromisoformat( 436 expires_at_str.replace("Z", "+00:00") 437 ) 438 except (ValueError, TypeError) as e: 439 token_errors[provider] = MCPTokenError( 440 code="INVALID_EXPIRES_AT", 441 message=f"Invalid expires_at format: {expires_at_str}", 442 ) 443 continue 444 445 tokens[provider] = MCPTokenInfo( 446 access_token=token_data.get("access_token", ""), 447 token_type=token_data.get("token_type", "Bearer"), 448 expires_at=expires_at, 449 scopes=token_data.get("scopes", []), 450 ) 451 452 # APIからのエラーを追加 453 for provider, err in errors_raw.items(): 454 token_errors[provider] = MCPTokenError( 455 code=err.get("code", "UNKNOWN"), 456 message=err.get("message", "Unknown error"), 457 ) 458 459 # エラーがあればerrorsに設定 460 final_errors = token_errors if token_errors else None 461 462 return GetMCPTokensResult( 463 success=True, 464 tokens=tokens, 465 errors=final_errors, 466 ) 467 468 except (AuthUnauthorizedError, AuthRateLimitError, AuthAPIError) as e: 469 logger.error(f"Failed to get MCP tokens for {user_id}: {e}") 470 return GetMCPTokensResult( 471 success=False, 472 error=str(e), 473 error_code=getattr(e, "error_code", "API_ERROR"), 474 ) 475 except Exception as e: 476 logger.error(f"Unexpected error getting MCP tokens for {user_id}: {e}") 477 return GetMCPTokensResult( 478 success=False, 479 error=str(e), 480 error_code="UNKNOWN_ERROR", 481 )
MCP用OAuthトークンを取得
Args: user_id: ユーザーID providers: 取得するプロバイダー(省略時は全プロバイダー) 例: ["github", "slack", "google", "office365"]
Returns: GetMCPTokensResult: MCPトークン取得結果
Example:
result = await client.get_mcp_tokens(
"user-id-123",
providers=["github", "slack"],
)
if result.success:
for provider, token in result.tokens.items():
print(f"{provider}: expires at {token.expires_at}")
if result.errors:
for provider, err in result.errors.items():
print(f"{provider} error: {err.message}")
487 async def close(self) -> None: 488 """クライアントリソースをクローズ""" 489 try: 490 if self._http_client and not self._http_client.is_closed: 491 await self._http_client.aclose() 492 except Exception as e: 493 logger.warning(f"Error closing HTTP client: {e}") 494 finally: 495 self._http_client = None 496 logger.debug("LibreChatAuthClient closed")
クライアントリソースをクローズ
17@dataclass 18class LibreChatAuthConfig: 19 """LibreChat Auth API設定 20 21 Args: 22 base_url: APIベースURL (例: "https://auth.sbtestfeature.agenticstar.tm.softbank.jp") 23 api_key: X-Internal-API-Key認証用APIキー 24 timeout: HTTPリクエストタイムアウト秒数 25 26 Note: 27 api_keyはrepr=Falseでログ出力から除外されます。 28 29 Example: 30 ```python 31 # コードから直接指定 32 config = LibreChatAuthConfig.create( 33 base_url="https://auth.example.com", 34 api_key="your-api-key", 35 ) 36 37 # config.tomlから読み込み 38 config = LibreChatAuthConfig.from_config() 39 ``` 40 """ 41 42 # 必須設定 43 base_url: str 44 api_key: str = field(repr=False) 45 46 # オプション設定 47 timeout: float = 30.0 48 49 def __str__(self) -> str: 50 """文字列表現(APIキーをマスク)""" 51 return f"LibreChatAuthConfig(base_url={self.base_url!r}, api_key=***)" 52 53 @classmethod 54 def create( 55 cls, 56 base_url: str, 57 api_key: str, 58 timeout: float = 30.0, 59 ) -> "LibreChatAuthConfig": 60 """コードから直接設定を指定して作成 61 62 Args: 63 base_url: APIベースURL 64 api_key: APIキー 65 timeout: タイムアウト秒数 66 67 Returns: 68 LibreChatAuthConfig インスタンス 69 70 Example: 71 ```python 72 config = LibreChatAuthConfig.create( 73 base_url="https://auth.example.com", 74 api_key="your-api-key", 75 ) 76 ``` 77 """ 78 return cls( 79 base_url=base_url, 80 api_key=api_key, 81 timeout=timeout, 82 ) 83 84 @classmethod 85 def from_config( 86 cls, 87 config_path: Optional[Path] = None, 88 ) -> "LibreChatAuthConfig": 89 """config.tomlから設定を読み込む 90 91 Args: 92 config_path: config.tomlのパス(省略時は自動検出) 93 94 Returns: 95 LibreChatAuthConfig インスタンス 96 97 Raises: 98 AuthConfigError: 設定ファイルの読み込みに失敗した場合 99 ValueError: 必須設定が見つからない場合 100 """ 101 toml_config = cls._load_toml_config(config_path) 102 103 base_url = toml_config.get("base_url") 104 api_key = toml_config.get("api_key") 105 timeout = toml_config.get("timeout", 30.0) 106 107 # バリデーション 108 if not base_url: 109 raise ValueError( 110 "base_url is required. Configure [auth.librechat] base_url in config.toml" 111 ) 112 113 if not api_key: 114 raise ValueError( 115 "api_key is required. Configure [auth.librechat] api_key in config.toml" 116 ) 117 118 return cls( 119 base_url=base_url, 120 api_key=api_key, 121 timeout=float(timeout), 122 ) 123 124 @staticmethod 125 def _load_toml_config(config_path: Optional[Path] = None) -> Dict[str, Any]: 126 """config.tomlから[auth.librechat]セクションを読み込む 127 128 Raises: 129 AuthConfigError: ファイル読み込みやパースに失敗した場合 130 """ 131 import tomllib 132 133 # パスが指定されていない場合は自動検出 134 if config_path is None: 135 # 優先順位: カレントディレクトリ > プロジェクトルート 136 candidates = [ 137 Path.cwd() / "config.toml", 138 Path(__file__).parent.parent.parent.parent.parent / "config.toml", 139 ] 140 for candidate in candidates: 141 if candidate.exists(): 142 config_path = candidate 143 break 144 145 if config_path is None or not config_path.exists(): 146 raise AuthConfigError( 147 f"config.toml not found. Searched: {[str(c) for c in candidates]}" 148 ) 149 150 try: 151 with open(config_path, "rb") as f: 152 config = tomllib.load(f) 153 except tomllib.TOMLDecodeError as e: 154 raise AuthConfigError(f"Invalid TOML syntax in {config_path}: {e}") 155 except PermissionError as e: 156 raise AuthConfigError(f"Permission denied reading {config_path}: {e}") 157 except Exception as e: 158 raise AuthConfigError(f"Failed to read {config_path}: {e}") 159 160 auth_librechat = config.get("auth", {}).get("librechat") 161 if auth_librechat is None: 162 raise AuthConfigError( 163 f"[auth.librechat] section not found in {config_path}" 164 ) 165 166 return auth_librechat
LibreChat Auth API設定
Args: base_url: APIベースURL (例: "https://auth.sbtestfeature.agenticstar.tm.softbank.jp") api_key: X-Internal-API-Key認証用APIキー timeout: HTTPリクエストタイムアウト秒数
Note: api_keyはrepr=Falseでログ出力から除外されます。
Example:
# コードから直接指定
config = LibreChatAuthConfig.create(
base_url="https://auth.example.com",
api_key="your-api-key",
)
# config.tomlから読み込み
config = LibreChatAuthConfig.from_config()
53 @classmethod 54 def create( 55 cls, 56 base_url: str, 57 api_key: str, 58 timeout: float = 30.0, 59 ) -> "LibreChatAuthConfig": 60 """コードから直接設定を指定して作成 61 62 Args: 63 base_url: APIベースURL 64 api_key: APIキー 65 timeout: タイムアウト秒数 66 67 Returns: 68 LibreChatAuthConfig インスタンス 69 70 Example: 71 ```python 72 config = LibreChatAuthConfig.create( 73 base_url="https://auth.example.com", 74 api_key="your-api-key", 75 ) 76 ``` 77 """ 78 return cls( 79 base_url=base_url, 80 api_key=api_key, 81 timeout=timeout, 82 )
コードから直接設定を指定して作成
Args: base_url: APIベースURL api_key: APIキー timeout: タイムアウト秒数
Returns: LibreChatAuthConfig インスタンス
Example:
config = LibreChatAuthConfig.create(
base_url="https://auth.example.com",
api_key="your-api-key",
)
84 @classmethod 85 def from_config( 86 cls, 87 config_path: Optional[Path] = None, 88 ) -> "LibreChatAuthConfig": 89 """config.tomlから設定を読み込む 90 91 Args: 92 config_path: config.tomlのパス(省略時は自動検出) 93 94 Returns: 95 LibreChatAuthConfig インスタンス 96 97 Raises: 98 AuthConfigError: 設定ファイルの読み込みに失敗した場合 99 ValueError: 必須設定が見つからない場合 100 """ 101 toml_config = cls._load_toml_config(config_path) 102 103 base_url = toml_config.get("base_url") 104 api_key = toml_config.get("api_key") 105 timeout = toml_config.get("timeout", 30.0) 106 107 # バリデーション 108 if not base_url: 109 raise ValueError( 110 "base_url is required. Configure [auth.librechat] base_url in config.toml" 111 ) 112 113 if not api_key: 114 raise ValueError( 115 "api_key is required. Configure [auth.librechat] api_key in config.toml" 116 ) 117 118 return cls( 119 base_url=base_url, 120 api_key=api_key, 121 timeout=float(timeout), 122 )
config.tomlから設定を読み込む
Args: config_path: config.tomlのパス(省略時は自動検出)
Returns: LibreChatAuthConfig インスタンス
Raises: AuthConfigError: 設定ファイルの読み込みに失敗した場合 ValueError: 必須設定が見つからない場合
Auth操作の基底エラー
16class AuthConfigError(AuthError): 17 """Auth設定エラー 18 19 Raises: 20 設定が不正または不足している場合 21 """ 22 23 pass
Auth設定エラー
Raises: 設定が不正または不足している場合
26class AuthAPIError(AuthError): 27 """Auth API呼び出しエラー 28 29 Attributes: 30 status_code: HTTPステータスコード 31 error_code: エラーコード文字列 32 """ 33 34 def __init__( 35 self, 36 message: str, 37 status_code: Optional[int] = None, 38 error_code: str = "API_ERROR", 39 ): 40 super().__init__(message) 41 self.status_code = status_code 42 self.error_code = error_code
Auth API呼び出しエラー
Attributes: status_code: HTTPステータスコード error_code: エラーコード文字列
56class AuthNotFoundError(AuthAPIError): 57 """リソース未発見エラー(404 Not Found) 58 59 Raises: 60 指定されたユーザーまたはリソースが存在しない場合 61 """ 62 63 def __init__(self, message: str = "Resource not found"): 64 super().__init__(message, status_code=404, error_code="NOT_FOUND")
リソース未発見エラー(404 Not Found)
Raises: 指定されたユーザーまたはリソースが存在しない場合
67class AuthRateLimitError(AuthAPIError): 68 """レート制限エラー(429 Too Many Requests) 69 70 Raises: 71 APIレート制限に達した場合 72 """ 73 74 def __init__(self, message: str = "Rate limit exceeded"): 75 super().__init__(message, status_code=429, error_code="RATE_LIMITED")
レート制限エラー(429 Too Many Requests)
Raises: APIレート制限に達した場合
39class ApiUser(BaseModel): 40 """ユーザー情報""" 41 42 id: str 43 username: Optional[str] = None 44 email: Optional[str] = None 45 first_name: Optional[str] = Field(None, alias="firstName") 46 last_name: Optional[str] = Field(None, alias="lastName") 47 display_name: Optional[str] = Field(None, alias="displayName") 48 email_verified: bool = Field(False, alias="emailVerified") 49 enabled: bool = True 50 created_timestamp: Optional[int] = Field(None, alias="created_timestamp") 51 attributes: Optional[Dict[str, Any]] = None 52 organization: Optional[str] = None 53 organization_label: Optional[str] = Field(None, alias="organizationLabel") 54 job: Optional[str] = None 55 job_label: Optional[str] = Field(None, alias="jobLabel") 56 bio: Optional[str] = None 57 last_login_at: Optional[str] = Field(None, alias="lastLoginAt") 58 59 model_config = {"populate_by_name": True} 60 61 # ID取得メソッド 62 def get_org_id(self) -> Optional[str]: 63 """組織IDを取得""" 64 return self.organization 65 66 def get_job_id(self) -> Optional[str]: 67 """職種IDを取得""" 68 return self.job 69 70 # ラベル(表示値)取得メソッド 71 def get_display_name(self) -> Optional[str]: 72 """表示名を取得""" 73 return self.display_name 74 75 def get_organization_label(self) -> Optional[str]: 76 """組織ラベルを取得""" 77 return self.organization_label 78 79 def get_job_label(self) -> Optional[str]: 80 """職種ラベルを取得""" 81 return self.job_label 82 83 def get_bio(self) -> Optional[str]: 84 """自己紹介を取得""" 85 return self.bio
ユーザー情報
75 def get_organization_label(self) -> Optional[str]: 76 """組織ラベルを取得""" 77 return self.organization_label
組織ラベルを取得
93class DeviceInfo(BaseModel): 94 """デバイス情報""" 95 96 id: Optional[str] = None 97 device_code: str 98 device_name: Optional[str] = None 99 device_type: Optional[str] = None 100 device_info: Optional[Dict[str, Any]] = None 101 is_active: Optional[bool] = None 102 created_at: Optional[str] = None 103 last_used_at: Optional[str] = None
デバイス情報
106class LoginHistoryEntry(BaseModel): 107 """ログイン履歴エントリ""" 108 109 timestamp: int 110 ip_address: str = Field(alias="ipAddress") 111 user_agent: str = Field(alias="userAgent") 112 platform: str 113 location: Optional[Dict[str, Optional[str]]] = None 114 success: bool 115 116 model_config = {"populate_by_name": True}
ログイン履歴エントリ
26class UserPagination(BaseModel): 27 """ページネーション情報""" 28 29 page: int 30 limit: int 31 total: int 32 total_pages: int = Field(alias="totalPages") 33 has_next: bool = Field(alias="hasNext") 34 has_previous: bool = Field(alias="hasPrevious") 35 36 model_config = {"populate_by_name": True}
ページネーション情報
124class MCPTokenInfo(BaseModel): 125 """MCPトークン情報""" 126 127 access_token: str 128 token_type: str 129 expires_at: datetime 130 scopes: List[str]
MCPトークン情報
133class MCPTokenError(BaseModel): 134 """MCPトークン取得エラー 135 136 Attributes: 137 code: エラーコード (CONNECTION_NOT_FOUND, TOKEN_EXPIRED, REFRESH_FAILED等) 138 message: エラーメッセージ 139 """ 140 141 code: str 142 message: str
MCPトークン取得エラー
Attributes: code: エラーコード (CONNECTION_NOT_FOUND, TOKEN_EXPIRED, REFRESH_FAILED等) message: エラーメッセージ
150class GetUsersResult(BaseModel): 151 """ユーザー一覧取得結果 152 153 Attributes: 154 success: 成功したかどうか 155 users: ユーザーリスト 156 pagination: ページネーション情報 157 error: エラーメッセージ(失敗時) 158 error_code: エラーコード(失敗時) 159 """ 160 161 success: bool 162 users: List[ApiUser] = Field(default_factory=list) 163 pagination: Optional[UserPagination] = None 164 error: Optional[str] = None 165 error_code: Optional[str] = None
ユーザー一覧取得結果
Attributes: success: 成功したかどうか users: ユーザーリスト pagination: ページネーション情報 error: エラーメッセージ(失敗時) error_code: エラーコード(失敗時)
168class GetUserResult(BaseModel): 169 """ユーザー詳細取得結果 170 171 Attributes: 172 success: 成功したかどうか 173 user: ユーザー情報 174 devices: デバイス情報リスト(空配列も保持) 175 login_history: ログイン履歴(空配列も保持) 176 error: エラーメッセージ(失敗時) 177 error_code: エラーコード(失敗時) 178 """ 179 180 success: bool 181 user: Optional[ApiUser] = None 182 devices: List[DeviceInfo] = Field(default_factory=list) 183 login_history: List[LoginHistoryEntry] = Field(default_factory=list) 184 is_deleted: bool = False 185 deleted_at: Optional[str] = None 186 deleted_by: Optional[str] = None 187 deletion_reason: Optional[str] = None 188 error: Optional[str] = None 189 error_code: Optional[str] = None
ユーザー詳細取得結果
Attributes: success: 成功したかどうか user: ユーザー情報 devices: デバイス情報リスト(空配列も保持) login_history: ログイン履歴(空配列も保持) error: エラーメッセージ(失敗時) error_code: エラーコード(失敗時)
192class GetMCPTokensResult(BaseModel): 193 """MCPトークン取得結果 194 195 Attributes: 196 success: 成功したかどうか 197 tokens: プロバイダー別トークン情報 198 errors: プロバイダー別エラー情報 199 error: 全体エラーメッセージ(失敗時) 200 error_code: 全体エラーコード(失敗時) 201 202 Example: 203 ```python 204 result = await client.get_mcp_tokens("user-id", ["github", "slack"]) 205 if result.success: 206 for provider, token in result.tokens.items(): 207 print(f"{provider}: expires at {token.expires_at}") 208 if result.errors: 209 for provider, err in result.errors.items(): 210 print(f"{provider} error: {err.message}") 211 ``` 212 """ 213 214 success: bool 215 tokens: Dict[str, MCPTokenInfo] = Field(default_factory=dict) 216 errors: Optional[Dict[str, MCPTokenError]] = None 217 error: Optional[str] = None 218 error_code: Optional[str] = None
MCPトークン取得結果
Attributes: success: 成功したかどうか tokens: プロバイダー別トークン情報 errors: プロバイダー別エラー情報 error: 全体エラーメッセージ(失敗時) error_code: 全体エラーコード(失敗時)
Example:
result = await client.get_mcp_tokens("user-id", ["github", "slack"])
if result.success:
for provider, token in result.tokens.items():
print(f"{provider}: expires at {token.expires_at}")
if result.errors:
for provider, err in result.errors.items():
print(f"{provider} error: {err.message}")