agenticstar_platform
AGENTICSTAR Platform SDK エンタープライズAIエージェント基盤のためのSDK
外部サービス(PostgreSQL、Qdrant、Azure Blob/S3/GCS、Content Safety/PII等)との通信機能を提供
1""" 2AGENTICSTAR Platform SDK 3エンタープライズAIエージェント基盤のためのSDK 4 5外部サービス(PostgreSQL、Qdrant、Azure Blob/S3/GCS、Content Safety/PII等)との通信機能を提供 6""" 7 8__version__ = "0.3.1" 9 10# Common utilities are internal - not exported as public API 11# Use agenticstar_platform.common for internal SDK development only 12 13# Database module 14from .db import PostgreSQLManager, PostgreSQLConfig, AzureADConfig, DataAccess, ConfigAccess 15 16# RAG module 17from .rag import EmbeddingGenerator, EmbeddingConfig, QdrantManager, QdrantConfig 18 19# Events module 20from .events import ( 21 EventType, 22 SubEventType, 23 StreamingEvent, 24 SequencedEvent, 25 ExecutionMessage, 26 EventEmitter, 27 EventHandler, 28 create_sse_handler, 29 create_json_handler, 30 # Event Handlers(マーケットプレイスUI連携用) 31 DatabaseEventHandler, 32 WebhookEventHandler, 33 CompositeEventHandler, 34 create_marketplace_handler, 35) 36 37# Memory module 38from .memory import ( 39 SemanticMemoryClient, 40 SemanticMemoryConfig, 41 LLMProviderConfig, 42) 43# Episodic memory is optional (requires graphiti-core[falkordb]) 44from .memory import EpisodicMemoryClient, EpisodicMemoryConfig, normalize_group_id 45 46# Storage module 47from .storage import ( 48 StorageProvider, 49 StorageConfig, 50 AzureBlobConfig, 51 S3Config, 52 GCSConfig, 53 AzureBlobStorageClient, 54 S3StorageClient, 55 GCSStorageClient, 56 UploadResult, 57 DownloadResult, 58 ObjectInfo, 59 ListResult, 60) 61 62# Security module 63from .security import ( 64 SecurityProvider, 65 ContentCategory, 66 PIICategory, 67 SecurityError, 68 SecurityConfigError, 69 SecurityAPIError, 70 ContentModerationResult, 71 PromptShieldResult, 72 PIIEntity, 73 PIIDetectionResult, 74 SecurityCheckResult, 75 AzureSecurityConfig, 76 AWSSecurityConfig, 77 GCPSecurityConfig, 78 AzureSecurityClient, 79 AWSSecurityClient, 80 GCPSecurityClient, 81) 82 83# Auth module 84from .auth import ( 85 AgenticStarAuthClient, 86 AgenticStarAuthConfig, 87 AuthError, 88 AuthConfigError, 89 AuthAPIError, 90 AuthUnauthorizedError, 91 AuthNotFoundError, 92 AuthRateLimitError, 93 ApiUser, 94 DeviceInfo, 95 LoginHistoryEntry, 96 UserPagination, 97 MCPTokenInfo, 98 MCPTokenError, 99 OAuthProviderName, 100 GetUsersResult, 101 GetUserResult, 102 GetMCPTokensResult, 103) 104 105__all__ = [ 106 # Database 107 "PostgreSQLManager", 108 "PostgreSQLConfig", 109 "AzureADConfig", 110 "DataAccess", 111 "ConfigAccess", 112 # RAG 113 "EmbeddingGenerator", 114 "EmbeddingConfig", 115 "QdrantManager", 116 "QdrantConfig", 117 # Events 118 "EventType", 119 "SubEventType", 120 "StreamingEvent", 121 "SequencedEvent", 122 "ExecutionMessage", 123 "EventEmitter", 124 "EventHandler", 125 "create_sse_handler", 126 "create_json_handler", 127 # Event Handlers(マーケットプレイスUI連携用) 128 "DatabaseEventHandler", 129 "WebhookEventHandler", 130 "CompositeEventHandler", 131 "create_marketplace_handler", 132 # Memory 133 "SemanticMemoryClient", 134 "SemanticMemoryConfig", 135 "EpisodicMemoryClient", 136 "EpisodicMemoryConfig", 137 "LLMProviderConfig", 138 "normalize_group_id", 139 # Storage 140 "StorageProvider", 141 "StorageConfig", 142 "AzureBlobConfig", 143 "S3Config", 144 "GCSConfig", 145 "AzureBlobStorageClient", 146 "S3StorageClient", 147 "GCSStorageClient", 148 "UploadResult", 149 "DownloadResult", 150 "ObjectInfo", 151 "ListResult", 152 # Security 153 "SecurityProvider", 154 "ContentCategory", 155 "PIICategory", 156 "SecurityError", 157 "SecurityConfigError", 158 "SecurityAPIError", 159 "ContentModerationResult", 160 "PromptShieldResult", 161 "PIIEntity", 162 "PIIDetectionResult", 163 "SecurityCheckResult", 164 "AzureSecurityConfig", 165 "AWSSecurityConfig", 166 "GCPSecurityConfig", 167 "AzureSecurityClient", 168 "AWSSecurityClient", 169 "GCPSecurityClient", 170 # Auth 171 "AgenticStarAuthClient", 172 "AgenticStarAuthConfig", 173 "AuthError", 174 "AuthConfigError", 175 "AuthAPIError", 176 "AuthUnauthorizedError", 177 "AuthNotFoundError", 178 "AuthRateLimitError", 179 "ApiUser", 180 "DeviceInfo", 181 "LoginHistoryEntry", 182 "UserPagination", 183 "MCPTokenInfo", 184 "MCPTokenError", 185 "OAuthProviderName", 186 "GetUsersResult", 187 "GetUserResult", 188 "GetMCPTokensResult", 189]
26class PostgreSQLManager: 27 """PostgreSQL接続とクエリ実行クラス 28 29 Azure AD認証またはユーザー名/パスワード認証でPostgreSQLに接続し、 30 クエリを実行します。接続プールを使用して効率的に接続を管理します。 31 32 Features: 33 - Azure AD認証対応(トークン自動更新) 34 - 標準認証(ユーザー名/パスワード)対応 35 - 接続プール管理 36 - 非同期コンテキストマネージャー対応 37 - SQL Injection防止(パラメータ化クエリ) 38 39 Example: 40 >>> from agenticstar_platform import PostgreSQLManager, PostgreSQLConfig 41 >>> 42 >>> # 設定を作成 43 >>> config = PostgreSQLConfig.from_toml("config.toml") 44 >>> 45 >>> # コンテキストマネージャーで使用(推奨) 46 >>> async with PostgreSQLManager(config) as db: 47 ... # SELECT実行 48 ... result = await db.fetch_all("SELECT * FROM users WHERE status = $1", ("active",)) 49 ... print(f"Active users: {len(result)}") 50 ... 51 ... # INSERT実行 52 ... await db.execute( 53 ... "INSERT INTO logs (message) VALUES ($1)", 54 ... ("Operation completed",) 55 ... ) 56 >>> 57 >>> # 明示的な初期化・クローズ 58 >>> db = PostgreSQLManager(config) 59 >>> await db.initialize() 60 >>> try: 61 ... result = await db.fetch_one("SELECT COUNT(*) as cnt FROM users") 62 ... print(f"Total users: {result['cnt']}") 63 ... finally: 64 ... await db.close() 65 """ 66 67 def __init__(self, config: PostgreSQLConfig): 68 self.config = config 69 self.pool: Optional[asyncpg.Pool] = None 70 self._access_token: Optional[str] = None 71 self._token_expiry: Optional[float] = None 72 self._token_refresh_lock: Optional[asyncio.Lock] = None 73 74 async def __aenter__(self) -> "PostgreSQLManager": 75 """Context Manager: 開始時に接続プールを初期化""" 76 await self.initialize() 77 return self 78 79 async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: 80 """Context Manager: 終了時に接続プールをクローズ""" 81 await self.close() 82 83 async def _get_azure_ad_token(self) -> str: 84 """Azure AD認証でアクセストークンを取得""" 85 logger.info("[IN] _get_azure_ad_token: Getting Azure AD token") 86 87 if not self.config.azure_ad: 88 logger.error("[ERROR] _get_azure_ad_token: Azure AD configuration is required") 89 raise ValueError("Azure AD configuration is required") 90 91 # トークンの有効期限をチェック(5分前に更新) 92 if ( 93 self._access_token 94 and self._token_expiry 95 and time.time() < self._token_expiry - 300 96 ): 97 logger.info("[OUT] _get_azure_ad_token: Using cached token") 98 return self._access_token 99 100 # トークン取得中の同期処理(複数API同時実行時の重複取得防止) 101 # Note: Lockはinitialize()で作成済み(イベントループ紐付けを保証) 102 async with self._token_refresh_lock: 103 # ロック取得後に再チェック(他のAPIコールが既に更新した可能性) 104 if ( 105 self._access_token 106 and self._token_expiry 107 and time.time() < self._token_expiry - 300 108 ): 109 return self._access_token 110 111 try: 112 credential = ClientSecretCredential( 113 tenant_id=self.config.azure_ad.tenant_id, 114 client_id=self.config.azure_ad.client_id, 115 client_secret=self.config.azure_ad.client_secret, 116 ) 117 118 # Try different scopes for Azure PostgreSQL Flexible Server 119 scopes_to_try = [ 120 "https://ossrdbms-aad.database.windows.net/.default", 121 "https://database.windows.net/.default", 122 "https://management.azure.com/.default", 123 ] 124 125 token = None 126 last_error = None 127 128 for scope in scopes_to_try: 129 try: 130 token = await asyncio.to_thread(credential.get_token, scope) 131 break 132 except Exception as e: 133 last_error = e 134 continue 135 136 if not token: 137 raise last_error or Exception("All token scopes failed") 138 139 self._access_token = token.token 140 self._token_expiry = token.expires_on 141 142 logger.info("[OUT] _get_azure_ad_token: Azure AD token acquired successfully") 143 return self._access_token 144 145 except Exception as e: 146 logger.error( 147 "[ERROR] _get_azure_ad_token: Failed to acquire Azure AD token - error_type=%s", 148 type(e).__name__, 149 ) 150 raise 151 152 async def _create_connection_params(self) -> Dict[str, Any]: 153 """接続パラメータを作成""" 154 logger.info("[IN] _create_connection_params: Creating database connection parameters") 155 params = { 156 "host": self.config.host, 157 "port": self.config.port, 158 "database": self.config.database, 159 "command_timeout": self.config.command_timeout, 160 } 161 162 # SSL設定(disable以外は設定する) 163 ssl_mode = getattr(self.config, 'ssl_mode', 'require') 164 if ssl_mode != "disable": 165 params["ssl"] = ssl_mode 166 167 if self.config.use_azure_ad: 168 token = await self._get_azure_ad_token() 169 170 # Use configured username if available, otherwise fallback to client_id 171 if self.config.azure_ad and self.config.azure_ad.username: 172 username = self.config.azure_ad.username 173 elif self.config.azure_ad: 174 username = self.config.azure_ad.client_id 175 else: 176 raise ValueError("Azure AD configuration is required when use_azure_ad is True") 177 178 params.update( 179 { 180 "user": username, 181 "password": token, 182 } 183 ) 184 else: 185 # 標準認証の場合(ユーザー名/パスワード) 186 if not self.config.username or not self.config.password: 187 raise ValueError("Username and password are required when use_azure_ad is False") 188 189 params.update( 190 { 191 "user": self.config.username, 192 "password": self.config.password, 193 } 194 ) 195 196 logger.info("[OUT] _create_connection_params: Connection parameters created successfully") 197 return params 198 199 async def _ensure_initialized(self) -> None: 200 """接続プールが初期化されていることを保証""" 201 if not self.pool: 202 await self.initialize() 203 204 async def initialize(self) -> None: 205 """データベース接続プールを初期化""" 206 logger.info("[IN] initialize: Initializing PostgreSQL connection pool") 207 try: 208 # トークンリフレッシュ用のLockを作成(イベントループ内で作成することで紐付けを保証) 209 if self._token_refresh_lock is None: 210 self._token_refresh_lock = asyncio.Lock() 211 212 logger.info("[IN] initialize.create_connection_params: Creating connection parameters") 213 connection_params = await self._create_connection_params() 214 logger.info("[OUT] initialize.create_connection_params: Connection parameters created") 215 216 logger.info( 217 "[IN] initialize.create_pool: Creating connection pool - min_size=%s, max_size=%s", 218 self.config.pool_min_size, 219 self.config.pool_max_size, 220 ) 221 self.pool = await asyncpg.create_pool( 222 min_size=self.config.pool_min_size, 223 max_size=self.config.pool_max_size, 224 **connection_params, 225 ) 226 logger.info("[OUT] initialize.create_pool: Connection pool created successfully") 227 logger.info("[OUT] initialize: PostgreSQL connection pool initialized successfully") 228 229 except Exception as e: 230 logger.error( 231 "[ERROR] initialize: PostgreSQL initialization failed - error_type=%s, error=%s", 232 type(e).__name__, 233 str(e), 234 ) 235 raise 236 237 async def close(self) -> None: 238 """接続プールを閉じる 239 240 Note: 241 例外が発生してもリソースを確実にクリーンアップします。 242 """ 243 logger.info("[IN] close: Closing PostgreSQL connection pool") 244 try: 245 if self.pool: 246 await self.pool.close() 247 logger.info("[OUT] close: PostgreSQL connection pool closed successfully") 248 else: 249 logger.info("[OUT] close: No connection pool to close") 250 except Exception as e: 251 logger.warning(f"Error closing PostgreSQL connection pool: {e}") 252 finally: 253 self.pool = None 254 self._access_token = None 255 self._token_expiry = None 256 257 async def execute_query( 258 self, query: str, params: tuple = () 259 ) -> Dict[str, Any]: 260 """ 261 汎用クエリ実行メソッド 262 263 SELECT/INSERT/UPDATE/DELETE等のSQLクエリを実行し、 264 結果を統一フォーマットで返します。 265 266 Args: 267 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 268 params: クエリパラメータ(タプル) 269 270 Returns: 271 Dict[str, Any]: 実行結果 272 - success: bool - 実行成功/失敗 273 - data: List[Dict] (SELECT時) or Dict (INSERT/UPDATE/DELETE時) 274 - error: Optional[str] - エラーメッセージ(失敗時のみ) 275 - error_code: Optional[str] - エラーコード(失敗時のみ) 276 277 Example: 278 >>> # SELECTクエリ 279 >>> result = await db.execute_query( 280 ... "SELECT * FROM users WHERE created_at > $1", 281 ... (datetime(2024, 1, 1),) 282 ... ) 283 >>> if result["success"]: 284 ... for user in result["data"]: 285 ... print(user["name"]) 286 >>> 287 >>> # INSERT with RETURNING 288 >>> result = await db.execute_query( 289 ... "INSERT INTO users (name) VALUES ($1) RETURNING id", 290 ... ("John",) 291 ... ) 292 >>> if result["success"]: 293 ... new_id = result["data"][0]["id"] 294 """ 295 logger.info( 296 "[IN] execute_query: Executing SQL query - query_type=%s, params_count=%s", 297 (query.strip().upper().split()[0] if query.strip() else "EMPTY"), 298 len(params), 299 ) 300 try: 301 await self._ensure_initialized() 302 303 async with self.pool.acquire() as conn: 304 # クエリの種類を判定 305 normalized_query = query.strip().upper() 306 is_select = normalized_query.startswith("SELECT") 307 has_returning = "RETURNING" in normalized_query 308 309 if is_select or has_returning: 310 # SELECT文またはRETURNING句がある場合はfetchを使用 311 rows = await conn.fetch(query, *params) 312 # Rowオブジェクトを辞書に変換 313 result_data = [dict(row) for row in rows] 314 logger.info( 315 "[OUT] execute_query: Query executed successfully - result_count=%s", 316 len(result_data), 317 ) 318 return {"success": True, "data": result_data} 319 else: 320 # INSERT/UPDATE/DELETE等(RETURNING句なし)の場合はexecuteを使用 321 result = await conn.execute(query, *params) 322 logger.info("[OUT] execute_query: Query executed successfully") 323 return {"success": True, "data": {"result": result}} 324 325 except Exception as e: 326 logger.error( 327 "[ERROR] execute_query: Query execution failed - error_type=%s, error=%s", 328 type(e).__name__, 329 str(e), 330 ) 331 return { 332 "success": False, 333 "data": None, 334 "error": f"Database query error: {str(e)}", 335 "error_code": "DB_QUERY_ERROR", 336 } 337 338 async def fetch_one( 339 self, query: str, params: tuple = () 340 ) -> Optional[Dict[str, Any]]: 341 """ 342 単一行を取得 343 344 1行だけ返すことが期待されるクエリに使用します。 345 結果が複数行ある場合でも最初の1行のみ返します。 346 347 Args: 348 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 349 params: クエリパラメータ(タプル) 350 351 Returns: 352 Optional[Dict[str, Any]]: 結果行(辞書形式)、結果がない場合はNone 353 354 Example: 355 >>> # ユーザーを1件取得 356 >>> user = await db.fetch_one( 357 ... "SELECT * FROM users WHERE id = $1", 358 ... (user_id,) 359 ... ) 360 >>> if user: 361 ... print(f"Found: {user['name']}") 362 ... else: 363 ... print("User not found") 364 >>> 365 >>> # 集計結果を取得 366 >>> count = await db.fetch_one("SELECT COUNT(*) as total FROM users") 367 >>> print(f"Total: {count['total']}") 368 """ 369 await self._ensure_initialized() 370 async with self.pool.acquire() as conn: 371 row = await conn.fetchrow(query, *params) 372 return dict(row) if row else None 373 374 async def fetch_all( 375 self, query: str, params: tuple = () 376 ) -> List[Dict[str, Any]]: 377 """ 378 全行を取得 379 380 複数行を返すSELECTクエリに使用します。 381 結果がない場合は空のリストを返します。 382 383 Args: 384 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 385 params: クエリパラメータ(タプル) 386 387 Returns: 388 List[Dict[str, Any]]: 結果行のリスト(各行は辞書形式) 389 390 Example: 391 >>> # 全ユーザーを取得 392 >>> users = await db.fetch_all("SELECT * FROM users ORDER BY created_at DESC") 393 >>> for user in users: 394 ... print(f"{user['id']}: {user['name']}") 395 >>> 396 >>> # 条件付きで取得 397 >>> active_users = await db.fetch_all( 398 ... "SELECT * FROM users WHERE status = $1 AND role = $2", 399 ... ("active", "admin") 400 ... ) 401 >>> print(f"Found {len(active_users)} admin users") 402 """ 403 await self._ensure_initialized() 404 async with self.pool.acquire() as conn: 405 rows = await conn.fetch(query, *params) 406 return [dict(row) for row in rows] 407 408 async def execute(self, query: str, params: tuple = ()) -> str: 409 """ 410 実行のみ(結果を返さないクエリ用) 411 412 INSERT/UPDATE/DELETE等の結果データが不要なクエリに使用します。 413 DDL文(CREATE TABLE等)の実行にも使用できます。 414 415 Args: 416 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 417 params: クエリパラメータ(タプル) 418 419 Returns: 420 str: 実行結果のステータス(例: "INSERT 0 1", "UPDATE 5", "DELETE 3") 421 422 Example: 423 >>> # データ挿入 424 >>> result = await db.execute( 425 ... "INSERT INTO logs (level, message) VALUES ($1, $2)", 426 ... ("INFO", "Operation completed") 427 ... ) 428 >>> print(result) # "INSERT 0 1" 429 >>> 430 >>> # データ更新 431 >>> result = await db.execute( 432 ... "UPDATE users SET status = $1 WHERE last_login < $2", 433 ... ("inactive", datetime(2024, 1, 1)) 434 ... ) 435 >>> print(result) # "UPDATE 15" 436 >>> 437 >>> # データ削除 438 >>> result = await db.execute( 439 ... "DELETE FROM sessions WHERE expired_at < $1", 440 ... (datetime.now(),) 441 ... ) 442 >>> print(result) # "DELETE 100" 443 """ 444 await self._ensure_initialized() 445 async with self.pool.acquire() as conn: 446 return await conn.execute(query, *params)
PostgreSQL接続とクエリ実行クラス
Azure AD認証またはユーザー名/パスワード認証でPostgreSQLに接続し、 クエリを実行します。接続プールを使用して効率的に接続を管理します。
Features: - Azure AD認証対応(トークン自動更新) - 標準認証(ユーザー名/パスワード)対応 - 接続プール管理 - 非同期コンテキストマネージャー対応 - SQL Injection防止(パラメータ化クエリ)
Example:
from agenticstar_platform import PostgreSQLManager, PostgreSQLConfig
設定を作成
config = PostgreSQLConfig.from_toml("config.toml")
コンテキストマネージャーで使用(推奨)
async with PostgreSQLManager(config) as db: ... # SELECT実行 ... result = await db.fetch_all("SELECT * FROM users WHERE status = $1", ("active",)) ... print(f"Active users: {len(result)}") ... ... # INSERT実行 ... await db.execute( ... "INSERT INTO logs (message) VALUES ($1)", ... ("Operation completed",) ... )
明示的な初期化・クローズ
db = PostgreSQLManager(config) await db.initialize() try: ... result = await db.fetch_one("SELECT COUNT(*) as cnt FROM users") ... print(f"Total users: {result['cnt']}") ... finally: ... await db.close()
204 async def initialize(self) -> None: 205 """データベース接続プールを初期化""" 206 logger.info("[IN] initialize: Initializing PostgreSQL connection pool") 207 try: 208 # トークンリフレッシュ用のLockを作成(イベントループ内で作成することで紐付けを保証) 209 if self._token_refresh_lock is None: 210 self._token_refresh_lock = asyncio.Lock() 211 212 logger.info("[IN] initialize.create_connection_params: Creating connection parameters") 213 connection_params = await self._create_connection_params() 214 logger.info("[OUT] initialize.create_connection_params: Connection parameters created") 215 216 logger.info( 217 "[IN] initialize.create_pool: Creating connection pool - min_size=%s, max_size=%s", 218 self.config.pool_min_size, 219 self.config.pool_max_size, 220 ) 221 self.pool = await asyncpg.create_pool( 222 min_size=self.config.pool_min_size, 223 max_size=self.config.pool_max_size, 224 **connection_params, 225 ) 226 logger.info("[OUT] initialize.create_pool: Connection pool created successfully") 227 logger.info("[OUT] initialize: PostgreSQL connection pool initialized successfully") 228 229 except Exception as e: 230 logger.error( 231 "[ERROR] initialize: PostgreSQL initialization failed - error_type=%s, error=%s", 232 type(e).__name__, 233 str(e), 234 ) 235 raise
データベース接続プールを初期化
237 async def close(self) -> None: 238 """接続プールを閉じる 239 240 Note: 241 例外が発生してもリソースを確実にクリーンアップします。 242 """ 243 logger.info("[IN] close: Closing PostgreSQL connection pool") 244 try: 245 if self.pool: 246 await self.pool.close() 247 logger.info("[OUT] close: PostgreSQL connection pool closed successfully") 248 else: 249 logger.info("[OUT] close: No connection pool to close") 250 except Exception as e: 251 logger.warning(f"Error closing PostgreSQL connection pool: {e}") 252 finally: 253 self.pool = None 254 self._access_token = None 255 self._token_expiry = None
接続プールを閉じる
Note: 例外が発生してもリソースを確実にクリーンアップします。
257 async def execute_query( 258 self, query: str, params: tuple = () 259 ) -> Dict[str, Any]: 260 """ 261 汎用クエリ実行メソッド 262 263 SELECT/INSERT/UPDATE/DELETE等のSQLクエリを実行し、 264 結果を統一フォーマットで返します。 265 266 Args: 267 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 268 params: クエリパラメータ(タプル) 269 270 Returns: 271 Dict[str, Any]: 実行結果 272 - success: bool - 実行成功/失敗 273 - data: List[Dict] (SELECT時) or Dict (INSERT/UPDATE/DELETE時) 274 - error: Optional[str] - エラーメッセージ(失敗時のみ) 275 - error_code: Optional[str] - エラーコード(失敗時のみ) 276 277 Example: 278 >>> # SELECTクエリ 279 >>> result = await db.execute_query( 280 ... "SELECT * FROM users WHERE created_at > $1", 281 ... (datetime(2024, 1, 1),) 282 ... ) 283 >>> if result["success"]: 284 ... for user in result["data"]: 285 ... print(user["name"]) 286 >>> 287 >>> # INSERT with RETURNING 288 >>> result = await db.execute_query( 289 ... "INSERT INTO users (name) VALUES ($1) RETURNING id", 290 ... ("John",) 291 ... ) 292 >>> if result["success"]: 293 ... new_id = result["data"][0]["id"] 294 """ 295 logger.info( 296 "[IN] execute_query: Executing SQL query - query_type=%s, params_count=%s", 297 (query.strip().upper().split()[0] if query.strip() else "EMPTY"), 298 len(params), 299 ) 300 try: 301 await self._ensure_initialized() 302 303 async with self.pool.acquire() as conn: 304 # クエリの種類を判定 305 normalized_query = query.strip().upper() 306 is_select = normalized_query.startswith("SELECT") 307 has_returning = "RETURNING" in normalized_query 308 309 if is_select or has_returning: 310 # SELECT文またはRETURNING句がある場合はfetchを使用 311 rows = await conn.fetch(query, *params) 312 # Rowオブジェクトを辞書に変換 313 result_data = [dict(row) for row in rows] 314 logger.info( 315 "[OUT] execute_query: Query executed successfully - result_count=%s", 316 len(result_data), 317 ) 318 return {"success": True, "data": result_data} 319 else: 320 # INSERT/UPDATE/DELETE等(RETURNING句なし)の場合はexecuteを使用 321 result = await conn.execute(query, *params) 322 logger.info("[OUT] execute_query: Query executed successfully") 323 return {"success": True, "data": {"result": result}} 324 325 except Exception as e: 326 logger.error( 327 "[ERROR] execute_query: Query execution failed - error_type=%s, error=%s", 328 type(e).__name__, 329 str(e), 330 ) 331 return { 332 "success": False, 333 "data": None, 334 "error": f"Database query error: {str(e)}", 335 "error_code": "DB_QUERY_ERROR", 336 }
汎用クエリ実行メソッド
SELECT/INSERT/UPDATE/DELETE等のSQLクエリを実行し、 結果を統一フォーマットで返します。
Args: query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 params: クエリパラメータ(タプル)
Returns: Dict[str, Any]: 実行結果 - success: bool - 実行成功/失敗 - data: List[Dict] (SELECT時) or Dict (INSERT/UPDATE/DELETE時) - error: Optional[str] - エラーメッセージ(失敗時のみ) - error_code: Optional[str] - エラーコード(失敗時のみ)
Example:
SELECTクエリ
result = await db.execute_query( ... "SELECT * FROM users WHERE created_at > $1", ... (datetime(2024, 1, 1),) ... ) if result["success"]: ... for user in result["data"]: ... print(user["name"])
INSERT with RETURNING
result = await db.execute_query( ... "INSERT INTO users (name) VALUES ($1) RETURNING id", ... ("John",) ... ) if result["success"]: ... new_id = result["data"][0]["id"]
338 async def fetch_one( 339 self, query: str, params: tuple = () 340 ) -> Optional[Dict[str, Any]]: 341 """ 342 単一行を取得 343 344 1行だけ返すことが期待されるクエリに使用します。 345 結果が複数行ある場合でも最初の1行のみ返します。 346 347 Args: 348 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 349 params: クエリパラメータ(タプル) 350 351 Returns: 352 Optional[Dict[str, Any]]: 結果行(辞書形式)、結果がない場合はNone 353 354 Example: 355 >>> # ユーザーを1件取得 356 >>> user = await db.fetch_one( 357 ... "SELECT * FROM users WHERE id = $1", 358 ... (user_id,) 359 ... ) 360 >>> if user: 361 ... print(f"Found: {user['name']}") 362 ... else: 363 ... print("User not found") 364 >>> 365 >>> # 集計結果を取得 366 >>> count = await db.fetch_one("SELECT COUNT(*) as total FROM users") 367 >>> print(f"Total: {count['total']}") 368 """ 369 await self._ensure_initialized() 370 async with self.pool.acquire() as conn: 371 row = await conn.fetchrow(query, *params) 372 return dict(row) if row else None
単一行を取得
1行だけ返すことが期待されるクエリに使用します。 結果が複数行ある場合でも最初の1行のみ返します。
Args: query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 params: クエリパラメータ(タプル)
Returns: Optional[Dict[str, Any]]: 結果行(辞書形式)、結果がない場合はNone
Example:
ユーザーを1件取得
user = await db.fetch_one( ... "SELECT * FROM users WHERE id = $1", ... (user_id,) ... ) if user: ... print(f"Found: {user['name']}") ... else: ... print("User not found")
集計結果を取得
count = await db.fetch_one("SELECT COUNT(*) as total FROM users") print(f"Total: {count['total']}")
374 async def fetch_all( 375 self, query: str, params: tuple = () 376 ) -> List[Dict[str, Any]]: 377 """ 378 全行を取得 379 380 複数行を返すSELECTクエリに使用します。 381 結果がない場合は空のリストを返します。 382 383 Args: 384 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 385 params: クエリパラメータ(タプル) 386 387 Returns: 388 List[Dict[str, Any]]: 結果行のリスト(各行は辞書形式) 389 390 Example: 391 >>> # 全ユーザーを取得 392 >>> users = await db.fetch_all("SELECT * FROM users ORDER BY created_at DESC") 393 >>> for user in users: 394 ... print(f"{user['id']}: {user['name']}") 395 >>> 396 >>> # 条件付きで取得 397 >>> active_users = await db.fetch_all( 398 ... "SELECT * FROM users WHERE status = $1 AND role = $2", 399 ... ("active", "admin") 400 ... ) 401 >>> print(f"Found {len(active_users)} admin users") 402 """ 403 await self._ensure_initialized() 404 async with self.pool.acquire() as conn: 405 rows = await conn.fetch(query, *params) 406 return [dict(row) for row in rows]
全行を取得
複数行を返すSELECTクエリに使用します。 結果がない場合は空のリストを返します。
Args: query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 params: クエリパラメータ(タプル)
Returns: List[Dict[str, Any]]: 結果行のリスト(各行は辞書形式)
Example:
全ユーザーを取得
users = await db.fetch_all("SELECT * FROM users ORDER BY created_at DESC") for user in users: ... print(f"{user['id']}: {user['name']}")
条件付きで取得
active_users = await db.fetch_all( ... "SELECT * FROM users WHERE status = $1 AND role = $2", ... ("active", "admin") ... ) print(f"Found {len(active_users)} admin users")
408 async def execute(self, query: str, params: tuple = ()) -> str: 409 """ 410 実行のみ(結果を返さないクエリ用) 411 412 INSERT/UPDATE/DELETE等の結果データが不要なクエリに使用します。 413 DDL文(CREATE TABLE等)の実行にも使用できます。 414 415 Args: 416 query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 417 params: クエリパラメータ(タプル) 418 419 Returns: 420 str: 実行結果のステータス(例: "INSERT 0 1", "UPDATE 5", "DELETE 3") 421 422 Example: 423 >>> # データ挿入 424 >>> result = await db.execute( 425 ... "INSERT INTO logs (level, message) VALUES ($1, $2)", 426 ... ("INFO", "Operation completed") 427 ... ) 428 >>> print(result) # "INSERT 0 1" 429 >>> 430 >>> # データ更新 431 >>> result = await db.execute( 432 ... "UPDATE users SET status = $1 WHERE last_login < $2", 433 ... ("inactive", datetime(2024, 1, 1)) 434 ... ) 435 >>> print(result) # "UPDATE 15" 436 >>> 437 >>> # データ削除 438 >>> result = await db.execute( 439 ... "DELETE FROM sessions WHERE expired_at < $1", 440 ... (datetime.now(),) 441 ... ) 442 >>> print(result) # "DELETE 100" 443 """ 444 await self._ensure_initialized() 445 async with self.pool.acquire() as conn: 446 return await conn.execute(query, *params)
実行のみ(結果を返さないクエリ用)
INSERT/UPDATE/DELETE等の結果データが不要なクエリに使用します。 DDL文(CREATE TABLE等)の実行にも使用できます。
Args: query: 実行するSQLクエリ。プレースホルダーは $1, $2 形式 params: クエリパラメータ(タプル)
Returns: str: 実行結果のステータス(例: "INSERT 0 1", "UPDATE 5", "DELETE 3")
Example:
データ挿入
result = await db.execute( ... "INSERT INTO logs (level, message) VALUES ($1, $2)", ... ("INFO", "Operation completed") ... ) print(result) # "INSERT 0 1"
データ更新
result = await db.execute( ... "UPDATE users SET status = $1 WHERE last_login < $2", ... ("inactive", datetime(2024, 1, 1)) ... ) print(result) # "UPDATE 15"
データ削除
result = await db.execute( ... "DELETE FROM sessions WHERE expired_at < $1", ... (datetime.now(),) ... ) print(result) # "DELETE 100"
101@dataclass 102class PostgreSQLConfig: 103 """PostgreSQL設定 104 105 Note: 106 passwordはrepr=Falseでログ出力から除外されます。 107 108 Example: 109 >>> # 辞書から作成 110 >>> config = PostgreSQLConfig.from_dict({ 111 ... "host": "localhost", 112 ... "port": 5432, 113 ... "database": "mydb", 114 ... "username": "user", 115 ... "password": "pass" 116 ... }) 117 118 >>> # TOMLファイルから作成 119 >>> config = PostgreSQLConfig.from_toml("config.toml", section="database") 120 """ 121 122 provider: str = "postgresql" 123 host: str = "localhost" 124 port: int = 5432 125 database: str = "agenticstar_db" 126 use_azure_ad: bool = False 127 128 # 標準認証用(use_azure_ad=Falseの場合) 129 username: Optional[str] = None 130 password: Optional[str] = field(default=None, repr=False) # シークレット: reprから除外 131 132 # 接続プール設定 133 pool_min_size: int = 5 134 pool_max_size: int = 20 135 command_timeout: int = 60 136 pool_timeout: Optional[int] = 30 137 max_overflow: Optional[int] = 10 138 139 # Azure AD認証設定 140 azure_ad: Optional[AzureADConfig] = None 141 142 # API Proxy設定(CLIモード用) 143 api_proxy_url: Optional[str] = None 144 145 # SSL設定(デフォルト: require, ローカル開発用: disable) 146 ssl_mode: str = "require" 147 148 def __str__(self) -> str: 149 return ( 150 f"PostgreSQLConfig(host={self.host!r}, port={self.port}, " 151 f"database={self.database!r}, use_azure_ad={self.use_azure_ad}, password=***)" 152 ) 153 154 @classmethod 155 def from_dict(cls, data: Dict[str, Any]) -> "PostgreSQLConfig": 156 """辞書からPostgreSQLConfigを作成 157 158 Args: 159 data: 設定辞書 160 161 Returns: 162 PostgreSQLConfig instance 163 164 Example: 165 >>> config = PostgreSQLConfig.from_dict({ 166 ... "host": "db.example.com", 167 ... "port": 5432, 168 ... "database": "mydb", 169 ... "username": "user", 170 ... "password": "secret", 171 ... "use_azure_ad": False, 172 ... "pool_max_size": 30 173 ... }) 174 """ 175 # Azure AD設定がある場合はネストして処理 176 azure_ad = None 177 if "azure_ad" in data and data.get("use_azure_ad"): 178 azure_ad = AzureADConfig.from_dict(data["azure_ad"]) 179 180 return cls( 181 provider=data.get("provider", "postgresql"), 182 host=data.get("host", "localhost"), 183 port=data.get("port", 5432), 184 database=data.get("database", "agenticstar_db"), 185 use_azure_ad=data.get("use_azure_ad", False), 186 username=data.get("username"), 187 password=data.get("password"), 188 pool_min_size=data.get("pool_min_size", 5), 189 pool_max_size=data.get("pool_max_size", 20), 190 command_timeout=data.get("command_timeout", 60), 191 pool_timeout=data.get("pool_timeout", 30), 192 max_overflow=data.get("max_overflow", 10), 193 azure_ad=azure_ad, 194 api_proxy_url=data.get("api_proxy_url"), 195 ) 196 197 @classmethod 198 def from_toml(cls, toml_path: str, section: str = "database") -> "PostgreSQLConfig": 199 """TOMLファイルからPostgreSQLConfigを作成 200 201 Args: 202 toml_path: TOMLファイルパス 203 section: セクション名(ドット区切りでネスト対応) 204 205 Returns: 206 PostgreSQLConfig instance 207 208 Example: 209 >>> # config.toml の [database] セクションを読み込み 210 >>> config = PostgreSQLConfig.from_toml("config.toml") 211 212 >>> # ネストされたセクション 213 >>> config = PostgreSQLConfig.from_toml("config.toml", section="services.database") 214 """ 215 if not TOML_AVAILABLE: 216 raise ImportError("toml library is required: pip install toml") 217 218 with open(toml_path, "r") as f: 219 toml_data = toml.load(f) 220 221 # ネストされたセクションをドット区切りで辿る 222 data = toml_data 223 for key in section.split("."): 224 data = data[key] 225 226 return cls.from_dict(data) 227 228 @classmethod 229 def from_env(cls, prefix: str = "DB_") -> "PostgreSQLConfig": 230 """環境変数からPostgreSQLConfigを作成 231 232 環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。 233 234 Args: 235 prefix: 環境変数名のプレフィックス(デフォルト: "DB_") 236 237 Returns: 238 PostgreSQLConfig instance 239 240 Example: 241 >>> # 環境変数から作成(DB_HOST, DB_PORT, DB_DATABASE, etc.) 242 >>> config = PostgreSQLConfig.from_env() 243 244 >>> # カスタムプレフィックス(POSTGRES_HOST, POSTGRES_PORT, etc.) 245 >>> config = PostgreSQLConfig.from_env(prefix="POSTGRES_") 246 247 Environment Variables: 248 {prefix}HOST: データベースホスト(デフォルト: localhost) 249 {prefix}PORT: ポート番号(デフォルト: 5432) 250 {prefix}DATABASE: データベース名(デフォルト: agenticstar_db) 251 {prefix}USER / {prefix}USERNAME: ユーザー名 252 {prefix}PASSWORD: パスワード 253 {prefix}USE_AZURE_AD: Azure AD認証を使用するか(true/false) 254 {prefix}POOL_MIN_SIZE: 最小プールサイズ(デフォルト: 5) 255 {prefix}POOL_MAX_SIZE: 最大プールサイズ(デフォルト: 20) 256 {prefix}API_PROXY_URL: API Proxy URL(CLIモード用) 257 """ 258 def get_env(key: str, default: str = "") -> str: 259 return os.environ.get(f"{prefix}{key}", default) 260 261 def get_env_int(key: str, default: int) -> int: 262 val = os.environ.get(f"{prefix}{key}") 263 return int(val) if val else default 264 265 def get_env_bool(key: str, default: bool = False) -> bool: 266 val = os.environ.get(f"{prefix}{key}", "").lower() 267 return val in ("true", "1", "yes") if val else default 268 269 # ユーザー名は USER または USERNAME どちらでも可 270 username = get_env("USER") or get_env("USERNAME") 271 272 return cls( 273 host=get_env("HOST", "localhost"), 274 port=get_env_int("PORT", 5432), 275 database=get_env("DATABASE", "agenticstar_db"), 276 username=username or None, 277 password=get_env("PASSWORD") or None, 278 use_azure_ad=get_env_bool("USE_AZURE_AD"), 279 pool_min_size=get_env_int("POOL_MIN_SIZE", 5), 280 pool_max_size=get_env_int("POOL_MAX_SIZE", 20), 281 command_timeout=get_env_int("COMMAND_TIMEOUT", 60), 282 pool_timeout=get_env_int("POOL_TIMEOUT", 30), 283 max_overflow=get_env_int("MAX_OVERFLOW", 10), 284 api_proxy_url=get_env("API_PROXY_URL") or None, 285 )
PostgreSQL設定
Note: passwordはrepr=Falseでログ出力から除外されます。
Example:
辞書から作成
config = PostgreSQLConfig.from_dict({ ... "host": "localhost", ... "port": 5432, ... "database": "mydb", ... "username": "user", ... "password": "pass" ... })
>>> # TOMLファイルから作成 >>> config = PostgreSQLConfig.from_toml("config.toml", section="database")
154 @classmethod 155 def from_dict(cls, data: Dict[str, Any]) -> "PostgreSQLConfig": 156 """辞書からPostgreSQLConfigを作成 157 158 Args: 159 data: 設定辞書 160 161 Returns: 162 PostgreSQLConfig instance 163 164 Example: 165 >>> config = PostgreSQLConfig.from_dict({ 166 ... "host": "db.example.com", 167 ... "port": 5432, 168 ... "database": "mydb", 169 ... "username": "user", 170 ... "password": "secret", 171 ... "use_azure_ad": False, 172 ... "pool_max_size": 30 173 ... }) 174 """ 175 # Azure AD設定がある場合はネストして処理 176 azure_ad = None 177 if "azure_ad" in data and data.get("use_azure_ad"): 178 azure_ad = AzureADConfig.from_dict(data["azure_ad"]) 179 180 return cls( 181 provider=data.get("provider", "postgresql"), 182 host=data.get("host", "localhost"), 183 port=data.get("port", 5432), 184 database=data.get("database", "agenticstar_db"), 185 use_azure_ad=data.get("use_azure_ad", False), 186 username=data.get("username"), 187 password=data.get("password"), 188 pool_min_size=data.get("pool_min_size", 5), 189 pool_max_size=data.get("pool_max_size", 20), 190 command_timeout=data.get("command_timeout", 60), 191 pool_timeout=data.get("pool_timeout", 30), 192 max_overflow=data.get("max_overflow", 10), 193 azure_ad=azure_ad, 194 api_proxy_url=data.get("api_proxy_url"), 195 )
辞書からPostgreSQLConfigを作成
Args: data: 設定辞書
Returns: PostgreSQLConfig instance
Example:
config = PostgreSQLConfig.from_dict({ ... "host": "db.example.com", ... "port": 5432, ... "database": "mydb", ... "username": "user", ... "password": "secret", ... "use_azure_ad": False, ... "pool_max_size": 30 ... })
197 @classmethod 198 def from_toml(cls, toml_path: str, section: str = "database") -> "PostgreSQLConfig": 199 """TOMLファイルからPostgreSQLConfigを作成 200 201 Args: 202 toml_path: TOMLファイルパス 203 section: セクション名(ドット区切りでネスト対応) 204 205 Returns: 206 PostgreSQLConfig instance 207 208 Example: 209 >>> # config.toml の [database] セクションを読み込み 210 >>> config = PostgreSQLConfig.from_toml("config.toml") 211 212 >>> # ネストされたセクション 213 >>> config = PostgreSQLConfig.from_toml("config.toml", section="services.database") 214 """ 215 if not TOML_AVAILABLE: 216 raise ImportError("toml library is required: pip install toml") 217 218 with open(toml_path, "r") as f: 219 toml_data = toml.load(f) 220 221 # ネストされたセクションをドット区切りで辿る 222 data = toml_data 223 for key in section.split("."): 224 data = data[key] 225 226 return cls.from_dict(data)
TOMLファイルからPostgreSQLConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: PostgreSQLConfig instance
Example:
config.toml の [database] セクションを読み込み
config = PostgreSQLConfig.from_toml("config.toml")
>>> # ネストされたセクション >>> config = PostgreSQLConfig.from_toml("config.toml", section="services.database")
228 @classmethod 229 def from_env(cls, prefix: str = "DB_") -> "PostgreSQLConfig": 230 """環境変数からPostgreSQLConfigを作成 231 232 環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。 233 234 Args: 235 prefix: 環境変数名のプレフィックス(デフォルト: "DB_") 236 237 Returns: 238 PostgreSQLConfig instance 239 240 Example: 241 >>> # 環境変数から作成(DB_HOST, DB_PORT, DB_DATABASE, etc.) 242 >>> config = PostgreSQLConfig.from_env() 243 244 >>> # カスタムプレフィックス(POSTGRES_HOST, POSTGRES_PORT, etc.) 245 >>> config = PostgreSQLConfig.from_env(prefix="POSTGRES_") 246 247 Environment Variables: 248 {prefix}HOST: データベースホスト(デフォルト: localhost) 249 {prefix}PORT: ポート番号(デフォルト: 5432) 250 {prefix}DATABASE: データベース名(デフォルト: agenticstar_db) 251 {prefix}USER / {prefix}USERNAME: ユーザー名 252 {prefix}PASSWORD: パスワード 253 {prefix}USE_AZURE_AD: Azure AD認証を使用するか(true/false) 254 {prefix}POOL_MIN_SIZE: 最小プールサイズ(デフォルト: 5) 255 {prefix}POOL_MAX_SIZE: 最大プールサイズ(デフォルト: 20) 256 {prefix}API_PROXY_URL: API Proxy URL(CLIモード用) 257 """ 258 def get_env(key: str, default: str = "") -> str: 259 return os.environ.get(f"{prefix}{key}", default) 260 261 def get_env_int(key: str, default: int) -> int: 262 val = os.environ.get(f"{prefix}{key}") 263 return int(val) if val else default 264 265 def get_env_bool(key: str, default: bool = False) -> bool: 266 val = os.environ.get(f"{prefix}{key}", "").lower() 267 return val in ("true", "1", "yes") if val else default 268 269 # ユーザー名は USER または USERNAME どちらでも可 270 username = get_env("USER") or get_env("USERNAME") 271 272 return cls( 273 host=get_env("HOST", "localhost"), 274 port=get_env_int("PORT", 5432), 275 database=get_env("DATABASE", "agenticstar_db"), 276 username=username or None, 277 password=get_env("PASSWORD") or None, 278 use_azure_ad=get_env_bool("USE_AZURE_AD"), 279 pool_min_size=get_env_int("POOL_MIN_SIZE", 5), 280 pool_max_size=get_env_int("POOL_MAX_SIZE", 20), 281 command_timeout=get_env_int("COMMAND_TIMEOUT", 60), 282 pool_timeout=get_env_int("POOL_TIMEOUT", 30), 283 max_overflow=get_env_int("MAX_OVERFLOW", 10), 284 api_proxy_url=get_env("API_PROXY_URL") or None, 285 )
環境変数からPostgreSQLConfigを作成
環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。
Args: prefix: 環境変数名のプレフィックス(デフォルト: "DB_")
Returns: PostgreSQLConfig instance
Example:
環境変数から作成(DB_HOST, DB_PORT, DB_DATABASE, etc.)
config = PostgreSQLConfig.from_env()
>>> # カスタムプレフィックス(POSTGRES_HOST, POSTGRES_PORT, etc.) >>> config = PostgreSQLConfig.from_env(prefix="POSTGRES_")Environment Variables: {prefix}HOST: データベースホスト(デフォルト: localhost) {prefix}PORT: ポート番号(デフォルト: 5432) {prefix}DATABASE: データベース名(デフォルト: agenticstar_db) {prefix}USER / {prefix}USERNAME: ユーザー名 {prefix}PASSWORD: パスワード {prefix}USE_AZURE_AD: Azure AD認証を使用するか(true/false) {prefix}POOL_MIN_SIZE: 最小プールサイズ(デフォルト: 5) {prefix}POOL_MAX_SIZE: 最大プールサイズ(デフォルト: 20) {prefix}API_PROXY_URL: API Proxy URL(CLIモード用)
21@dataclass 22class AzureADConfig: 23 """Azure AD認証設定 24 25 Note: 26 client_secretはrepr=Falseでログ出力から除外されます。 27 28 Example: 29 >>> # 辞書から作成 30 >>> config = AzureADConfig.from_dict({ 31 ... "tenant_id": "xxx", 32 ... "client_id": "yyy", 33 ... "client_secret": "zzz" 34 ... }) 35 36 >>> # TOMLファイルから作成 37 >>> config = AzureADConfig.from_toml("config.toml", section="database.azure_ad") 38 """ 39 40 tenant_id: str 41 client_id: str 42 client_secret: str = field(repr=False) # シークレット: reprから除外 43 username: str = "" # PostgreSQL Azure AD authentication username 44 45 def __str__(self) -> str: 46 return f"AzureADConfig(tenant_id={self.tenant_id!r}, client_id={self.client_id!r}, client_secret=***)" 47 48 @classmethod 49 def from_dict(cls, data: Dict[str, Any]) -> "AzureADConfig": 50 """辞書からAzureADConfigを作成 51 52 Args: 53 data: 設定辞書 54 55 Returns: 56 AzureADConfig instance 57 58 Example: 59 >>> config = AzureADConfig.from_dict({ 60 ... "tenant_id": "your-tenant-id", 61 ... "client_id": "your-client-id", 62 ... "client_secret": "your-secret" 63 ... }) 64 """ 65 return cls( 66 tenant_id=data["tenant_id"], 67 client_id=data["client_id"], 68 client_secret=data["client_secret"], 69 username=data.get("username", ""), 70 ) 71 72 @classmethod 73 def from_toml(cls, toml_path: str, section: str = "azure_ad") -> "AzureADConfig": 74 """TOMLファイルからAzureADConfigを作成 75 76 Args: 77 toml_path: TOMLファイルパス 78 section: セクション名(ドット区切りでネスト対応) 79 80 Returns: 81 AzureADConfig instance 82 83 Example: 84 >>> # config.toml の [database.azure_ad] セクションを読み込み 85 >>> config = AzureADConfig.from_toml("config.toml", section="database.azure_ad") 86 """ 87 if not TOML_AVAILABLE: 88 raise ImportError("toml library is required: pip install toml") 89 90 with open(toml_path, "r") as f: 91 toml_data = toml.load(f) 92 93 # ネストされたセクションをドット区切りで辿る 94 data = toml_data 95 for key in section.split("."): 96 data = data[key] 97 98 return cls.from_dict(data)
Azure AD認証設定
Note: client_secretはrepr=Falseでログ出力から除外されます。
Example:
辞書から作成
config = AzureADConfig.from_dict({ ... "tenant_id": "xxx", ... "client_id": "yyy", ... "client_secret": "zzz" ... })
>>> # TOMLファイルから作成 >>> config = AzureADConfig.from_toml("config.toml", section="database.azure_ad")
48 @classmethod 49 def from_dict(cls, data: Dict[str, Any]) -> "AzureADConfig": 50 """辞書からAzureADConfigを作成 51 52 Args: 53 data: 設定辞書 54 55 Returns: 56 AzureADConfig instance 57 58 Example: 59 >>> config = AzureADConfig.from_dict({ 60 ... "tenant_id": "your-tenant-id", 61 ... "client_id": "your-client-id", 62 ... "client_secret": "your-secret" 63 ... }) 64 """ 65 return cls( 66 tenant_id=data["tenant_id"], 67 client_id=data["client_id"], 68 client_secret=data["client_secret"], 69 username=data.get("username", ""), 70 )
辞書からAzureADConfigを作成
Args: data: 設定辞書
Returns: AzureADConfig instance
Example:
config = AzureADConfig.from_dict({ ... "tenant_id": "your-tenant-id", ... "client_id": "your-client-id", ... "client_secret": "your-secret" ... })
72 @classmethod 73 def from_toml(cls, toml_path: str, section: str = "azure_ad") -> "AzureADConfig": 74 """TOMLファイルからAzureADConfigを作成 75 76 Args: 77 toml_path: TOMLファイルパス 78 section: セクション名(ドット区切りでネスト対応) 79 80 Returns: 81 AzureADConfig instance 82 83 Example: 84 >>> # config.toml の [database.azure_ad] セクションを読み込み 85 >>> config = AzureADConfig.from_toml("config.toml", section="database.azure_ad") 86 """ 87 if not TOML_AVAILABLE: 88 raise ImportError("toml library is required: pip install toml") 89 90 with open(toml_path, "r") as f: 91 toml_data = toml.load(f) 92 93 # ネストされたセクションをドット区切りで辿る 94 data = toml_data 95 for key in section.split("."): 96 data = data[key] 97 98 return cls.from_dict(data)
TOMLファイルからAzureADConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: AzureADConfig instance
Example:
config.toml の [database.azure_ad] セクションを読み込み
config = AzureADConfig.from_toml("config.toml", section="database.azure_ad")
45class DataAccess: 46 """ 47 汎用データアクセスクラス 48 49 テーブル操作の基本機能を提供: 50 - insert: 単一レコード挿入 51 - upsert: 挿入または更新 52 - select: 条件付き検索 53 - update: 条件付き更新 54 - delete: 条件付き削除 55 - execute_query: 生SQLクエリ実行 56 57 Security: 58 全てのテーブル名・カラム名は _validate_identifier() で検証されます。 59 不正な識別子が渡された場合は SQLInjectionError が発生します。 60 61 Example: 62 async with DataAccess(config) as da: 63 result = await da.select("users", where={"id": 1}) 64 """ 65 66 def __init__( 67 self, 68 config: PostgreSQLConfig, 69 *, 70 use_proxy: bool = False, 71 token_provider: Optional[Callable[[], Optional[str]]] = None, 72 ): 73 """ 74 Args: 75 config: PostgreSQL設定(位置引数、後方互換) 76 use_proxy: API Proxy経由でアクセスするか(デフォルト: False) 77 token_provider: Proxy使用時の認証トークン取得関数 78 79 Note: 80 後方互換性のため、use_proxyとtoken_providerはキーワード引数のみ。 81 既存コード(config引数のみ)は引き続き動作する。 82 """ 83 self.db = create_postgresql_manager( 84 config, 85 use_proxy=use_proxy, 86 token_provider=token_provider, 87 ) 88 self.logger = logger 89 90 async def __aenter__(self) -> "DataAccess": 91 """Context Manager: 開始時にデータベース接続を初期化""" 92 await self.initialize() 93 return self 94 95 async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: 96 """Context Manager: 終了時にデータベース接続をクローズ""" 97 await self.close() 98 99 def _validate_identifier(self, identifier: str, identifier_type: str = "identifier") -> str: 100 """ 101 SQL識別子(テーブル名・カラム名)を検証 102 103 Args: 104 identifier: 検証する識別子 105 identifier_type: エラーメッセージ用の識別子タイプ 106 107 Returns: 108 検証済みの識別子 109 110 Raises: 111 SQLInjectionError: 不正な識別子の場合 112 """ 113 if not identifier or not _IDENTIFIER_PATTERN.match(identifier): 114 raise SQLInjectionError(identifier, identifier_type) 115 return identifier 116 117 def _validate_identifiers(self, identifiers: List[str], identifier_type: str = "column") -> List[str]: 118 """ 119 複数のSQL識別子を検証 120 121 Args: 122 identifiers: 検証する識別子のリスト 123 identifier_type: エラーメッセージ用の識別子タイプ 124 125 Returns: 126 検証済みの識別子リスト 127 128 Raises: 129 SQLInjectionError: 不正な識別子が含まれる場合 130 """ 131 return [self._validate_identifier(ident, identifier_type) for ident in identifiers] 132 133 def _validate_order_by(self, order_by: str) -> str: 134 """ 135 ORDER BY句を検証 136 137 Args: 138 order_by: ORDER BY句(例: "created_at DESC") 139 140 Returns: 141 検証済みのORDER BY句 142 143 Raises: 144 SQLInjectionError: 不正なORDER BY句の場合 145 """ 146 if not order_by or not _ORDER_BY_PATTERN.match(order_by.strip()): 147 raise SQLInjectionError(order_by, "ORDER BY clause") 148 return order_by 149 150 async def initialize(self) -> None: 151 """データベース接続を初期化""" 152 await self.db.initialize() 153 154 def is_initialized(self) -> bool: 155 """データベース接続が初期化されているか確認 156 157 Note: 158 PostgreSQLManager: pool属性で判定 159 ApiPostgreSQLManager: pool属性がないため常にTrue(HTTP経由で都度接続) 160 """ 161 if hasattr(self.db, 'pool'): 162 return self.db.pool is not None 163 # ApiPostgreSQLManager等、pool属性を持たないマネージャは常に初期化済み扱い 164 return True 165 166 async def ensure_initialized(self) -> None: 167 """データベース接続が初期化されていなければ初期化""" 168 if not self.is_initialized(): 169 await self.initialize() 170 171 async def close(self) -> None: 172 """データベース接続を閉じる""" 173 await self.db.close() 174 175 async def execute_query( 176 self, query: str, params: tuple = () 177 ) -> Dict[str, Any]: 178 """ 179 生SQLクエリを実行 180 181 Args: 182 query: SQLクエリ 183 params: クエリパラメータ 184 185 Returns: 186 Dict with keys: success, data, error (optional) 187 """ 188 return await self.db.execute_query(query, params) 189 190 async def insert( 191 self, 192 table: str, 193 data: Dict[str, Any], 194 returning: Optional[List[str]] = None, 195 ) -> Dict[str, Any]: 196 """ 197 単一レコードを挿入 198 199 Args: 200 table: テーブル名 201 data: 挿入するデータ {column: value} 202 returning: 返却するカラム名のリスト 203 204 Returns: 205 Dict with keys: success, data, error (optional) 206 207 Raises: 208 SQLInjectionError: 不正なテーブル名・カラム名の場合 209 """ 210 # SQL Injection対策: 識別子を検証 211 safe_table = self._validate_identifier(table, "table") 212 columns = list(data.keys()) 213 safe_columns = self._validate_identifiers(columns, "column") 214 215 placeholders = [f"${i+1}" for i in range(len(safe_columns))] 216 values = [self._serialize_value(v) for v in data.values()] 217 218 query = f"INSERT INTO {safe_table} ({', '.join(safe_columns)}) VALUES ({', '.join(placeholders)})" 219 220 if returning: 221 safe_returning = self._validate_identifiers(returning, "returning column") 222 query += f" RETURNING {', '.join(safe_returning)}" 223 224 return await self.db.execute_query(query, tuple(values)) 225 226 async def upsert( 227 self, 228 table: str, 229 data: Dict[str, Any], 230 conflict_columns: List[str], 231 update_columns: Optional[List[str]] = None, 232 returning: Optional[List[str]] = None, 233 ) -> Dict[str, Any]: 234 """ 235 挿入または更新(UPSERT) 236 237 Args: 238 table: テーブル名 239 data: 挿入/更新するデータ {column: value} 240 conflict_columns: 競合判定に使うカラム 241 update_columns: 競合時に更新するカラム(Noneの場合は全カラム) 242 returning: 返却するカラム名のリスト 243 244 Returns: 245 Dict with keys: success, data, error (optional) 246 247 Raises: 248 SQLInjectionError: 不正なテーブル名・カラム名の場合 249 """ 250 # SQL Injection対策: 識別子を検証 251 safe_table = self._validate_identifier(table, "table") 252 columns = list(data.keys()) 253 safe_columns = self._validate_identifiers(columns, "column") 254 safe_conflict_columns = self._validate_identifiers(conflict_columns, "conflict column") 255 256 placeholders = [f"${i+1}" for i in range(len(safe_columns))] 257 values = [self._serialize_value(v) for v in data.values()] 258 259 # 更新対象カラムを決定 260 if update_columns is None: 261 update_columns = [c for c in columns if c not in conflict_columns] 262 safe_update_columns = self._validate_identifiers(update_columns, "update column") 263 264 # UPDATE SET句を構築 265 update_set = ", ".join( 266 [f"{col} = EXCLUDED.{col}" for col in safe_update_columns] 267 ) 268 269 query = f""" 270 INSERT INTO {safe_table} ({', '.join(safe_columns)}) 271 VALUES ({', '.join(placeholders)}) 272 ON CONFLICT ({', '.join(safe_conflict_columns)}) 273 DO UPDATE SET {update_set} 274 """ 275 276 if returning: 277 safe_returning = self._validate_identifiers(returning, "returning column") 278 query += f" RETURNING {', '.join(safe_returning)}" 279 280 return await self.db.execute_query(query, tuple(values)) 281 282 async def select( 283 self, 284 table: str, 285 columns: Optional[List[str]] = None, 286 where: Optional[Dict[str, Any]] = None, 287 order_by: Optional[str] = None, 288 limit: Optional[int] = None, 289 ) -> Dict[str, Any]: 290 """ 291 条件付き検索 292 293 Args: 294 table: テーブル名 295 columns: 取得するカラム(Noneの場合は全カラム) 296 where: 検索条件 {column: value} 297 order_by: ソート指定(例: "created_at DESC") 298 limit: 取得件数上限 299 300 Returns: 301 Dict with keys: success, data (List[Dict]), error (optional) 302 303 Raises: 304 SQLInjectionError: 不正なテーブル名・カラム名の場合 305 """ 306 # SQL Injection対策: 識別子を検証 307 safe_table = self._validate_identifier(table, "table") 308 309 if columns is None: 310 select_cols = "*" 311 else: 312 safe_columns = self._validate_identifiers(columns, "column") 313 select_cols = ", ".join(safe_columns) 314 315 query = f"SELECT {select_cols} FROM {safe_table}" 316 params: List[Any] = [] 317 318 if where: 319 conditions = [] 320 param_index = 1 321 for col, val in where.items(): 322 safe_col = self._validate_identifier(col, "where column") 323 if val is None: 324 conditions.append(f"{safe_col} IS NULL") 325 else: 326 conditions.append(f"{safe_col} = ${param_index}") 327 params.append(self._serialize_value(val)) 328 param_index += 1 329 query += f" WHERE {' AND '.join(conditions)}" 330 331 if order_by: 332 safe_order_by = self._validate_order_by(order_by) 333 query += f" ORDER BY {safe_order_by}" 334 335 if limit is not None: 336 if not isinstance(limit, int) or limit < 0: 337 raise ValueError(f"Invalid limit value: {limit}") 338 query += f" LIMIT {limit}" 339 340 return await self.db.execute_query(query, tuple(params)) 341 342 async def select_one( 343 self, 344 table: str, 345 columns: Optional[List[str]] = None, 346 where: Optional[Dict[str, Any]] = None, 347 ) -> Optional[Dict[str, Any]]: 348 """ 349 単一レコードを取得 350 351 Args: 352 table: テーブル名 353 columns: 取得するカラム 354 where: 検索条件 355 356 Returns: 357 Dict (結果がある場合) or None 358 """ 359 result = await self.select(table, columns, where, limit=1) 360 if result.get("success") and result.get("data"): 361 return result["data"][0] 362 return None 363 364 async def update( 365 self, 366 table: str, 367 data: Dict[str, Any], 368 where: Dict[str, Any], 369 returning: Optional[List[str]] = None, 370 ) -> Dict[str, Any]: 371 """ 372 条件付き更新 373 374 Args: 375 table: テーブル名 376 data: 更新するデータ {column: value} 377 where: 更新条件 {column: value} 378 returning: 返却するカラム名のリスト 379 380 Returns: 381 Dict with keys: success, data, error (optional) 382 383 Raises: 384 SQLInjectionError: 不正なテーブル名・カラム名の場合 385 """ 386 # SQL Injection対策: 識別子を検証 387 safe_table = self._validate_identifier(table, "table") 388 389 # SET句を構築 390 set_columns = list(data.keys()) 391 safe_set_columns = self._validate_identifiers(set_columns, "set column") 392 set_clause = ", ".join([f"{col} = ${i+1}" for i, col in enumerate(safe_set_columns)]) 393 params: List[Any] = [self._serialize_value(v) for v in data.values()] 394 395 # WHERE句を構築 396 where_columns = list(where.keys()) 397 safe_where_columns = self._validate_identifiers(where_columns, "where column") 398 where_clause = " AND ".join( 399 [f"{col} = ${len(safe_set_columns) + i + 1}" for i, col in enumerate(safe_where_columns)] 400 ) 401 params.extend([self._serialize_value(v) for v in where.values()]) 402 403 query = f"UPDATE {safe_table} SET {set_clause} WHERE {where_clause}" 404 405 if returning: 406 safe_returning = self._validate_identifiers(returning, "returning column") 407 query += f" RETURNING {', '.join(safe_returning)}" 408 409 return await self.db.execute_query(query, tuple(params)) 410 411 async def delete( 412 self, 413 table: str, 414 where: Dict[str, Any], 415 returning: Optional[List[str]] = None, 416 ) -> Dict[str, Any]: 417 """ 418 条件付き削除 419 420 Args: 421 table: テーブル名 422 where: 削除条件 {column: value} 423 returning: 返却するカラム名のリスト 424 425 Returns: 426 Dict with keys: success, data, error (optional) 427 428 Raises: 429 SQLInjectionError: 不正なテーブル名・カラム名の場合 430 """ 431 # SQL Injection対策: 識別子を検証 432 safe_table = self._validate_identifier(table, "table") 433 434 # WHERE句を構築 435 conditions = [] 436 params: List[Any] = [] 437 param_index = 1 438 for col, val in where.items(): 439 safe_col = self._validate_identifier(col, "where column") 440 if val is None: 441 conditions.append(f"{safe_col} IS NULL") 442 else: 443 conditions.append(f"{safe_col} = ${param_index}") 444 params.append(self._serialize_value(val)) 445 param_index += 1 446 447 query = f"DELETE FROM {safe_table} WHERE {' AND '.join(conditions)}" 448 449 if returning: 450 safe_returning = self._validate_identifiers(returning, "returning column") 451 query += f" RETURNING {', '.join(safe_returning)}" 452 453 return await self.db.execute_query(query, tuple(params)) 454 455 def _serialize_value(self, value: Any) -> Any: 456 """ 457 値をDB保存用にシリアライズ 458 459 Args: 460 value: シリアライズする値 461 462 Returns: 463 シリアライズされた値 464 """ 465 if isinstance(value, dict) or isinstance(value, list): 466 return json.dumps(value, ensure_ascii=False) 467 return value
汎用データアクセスクラス
テーブル操作の基本機能を提供:
- insert: 単一レコード挿入
- upsert: 挿入または更新
- select: 条件付き検索
- update: 条件付き更新
- delete: 条件付き削除
- execute_query: 生SQLクエリ実行
Security: 全てのテーブル名・カラム名は _validate_identifier() で検証されます。 不正な識別子が渡された場合は SQLInjectionError が発生します。
Example: async with DataAccess(config) as da: result = await da.select("users", where={"id": 1})
66 def __init__( 67 self, 68 config: PostgreSQLConfig, 69 *, 70 use_proxy: bool = False, 71 token_provider: Optional[Callable[[], Optional[str]]] = None, 72 ): 73 """ 74 Args: 75 config: PostgreSQL設定(位置引数、後方互換) 76 use_proxy: API Proxy経由でアクセスするか(デフォルト: False) 77 token_provider: Proxy使用時の認証トークン取得関数 78 79 Note: 80 後方互換性のため、use_proxyとtoken_providerはキーワード引数のみ。 81 既存コード(config引数のみ)は引き続き動作する。 82 """ 83 self.db = create_postgresql_manager( 84 config, 85 use_proxy=use_proxy, 86 token_provider=token_provider, 87 ) 88 self.logger = logger
Args: config: PostgreSQL設定(位置引数、後方互換) use_proxy: API Proxy経由でアクセスするか(デフォルト: False) token_provider: Proxy使用時の認証トークン取得関数
Note: 後方互換性のため、use_proxyとtoken_providerはキーワード引数のみ。 既存コード(config引数のみ)は引き続き動作する。
154 def is_initialized(self) -> bool: 155 """データベース接続が初期化されているか確認 156 157 Note: 158 PostgreSQLManager: pool属性で判定 159 ApiPostgreSQLManager: pool属性がないため常にTrue(HTTP経由で都度接続) 160 """ 161 if hasattr(self.db, 'pool'): 162 return self.db.pool is not None 163 # ApiPostgreSQLManager等、pool属性を持たないマネージャは常に初期化済み扱い 164 return True
データベース接続が初期化されているか確認
Note: PostgreSQLManager: pool属性で判定 ApiPostgreSQLManager: pool属性がないため常にTrue(HTTP経由で都度接続)
166 async def ensure_initialized(self) -> None: 167 """データベース接続が初期化されていなければ初期化""" 168 if not self.is_initialized(): 169 await self.initialize()
データベース接続が初期化されていなければ初期化
175 async def execute_query( 176 self, query: str, params: tuple = () 177 ) -> Dict[str, Any]: 178 """ 179 生SQLクエリを実行 180 181 Args: 182 query: SQLクエリ 183 params: クエリパラメータ 184 185 Returns: 186 Dict with keys: success, data, error (optional) 187 """ 188 return await self.db.execute_query(query, params)
生SQLクエリを実行
Args: query: SQLクエリ params: クエリパラメータ
Returns: Dict with keys: success, data, error (optional)
190 async def insert( 191 self, 192 table: str, 193 data: Dict[str, Any], 194 returning: Optional[List[str]] = None, 195 ) -> Dict[str, Any]: 196 """ 197 単一レコードを挿入 198 199 Args: 200 table: テーブル名 201 data: 挿入するデータ {column: value} 202 returning: 返却するカラム名のリスト 203 204 Returns: 205 Dict with keys: success, data, error (optional) 206 207 Raises: 208 SQLInjectionError: 不正なテーブル名・カラム名の場合 209 """ 210 # SQL Injection対策: 識別子を検証 211 safe_table = self._validate_identifier(table, "table") 212 columns = list(data.keys()) 213 safe_columns = self._validate_identifiers(columns, "column") 214 215 placeholders = [f"${i+1}" for i in range(len(safe_columns))] 216 values = [self._serialize_value(v) for v in data.values()] 217 218 query = f"INSERT INTO {safe_table} ({', '.join(safe_columns)}) VALUES ({', '.join(placeholders)})" 219 220 if returning: 221 safe_returning = self._validate_identifiers(returning, "returning column") 222 query += f" RETURNING {', '.join(safe_returning)}" 223 224 return await self.db.execute_query(query, tuple(values))
単一レコードを挿入
Args: table: テーブル名 data: 挿入するデータ {column: value} returning: 返却するカラム名のリスト
Returns: Dict with keys: success, data, error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
226 async def upsert( 227 self, 228 table: str, 229 data: Dict[str, Any], 230 conflict_columns: List[str], 231 update_columns: Optional[List[str]] = None, 232 returning: Optional[List[str]] = None, 233 ) -> Dict[str, Any]: 234 """ 235 挿入または更新(UPSERT) 236 237 Args: 238 table: テーブル名 239 data: 挿入/更新するデータ {column: value} 240 conflict_columns: 競合判定に使うカラム 241 update_columns: 競合時に更新するカラム(Noneの場合は全カラム) 242 returning: 返却するカラム名のリスト 243 244 Returns: 245 Dict with keys: success, data, error (optional) 246 247 Raises: 248 SQLInjectionError: 不正なテーブル名・カラム名の場合 249 """ 250 # SQL Injection対策: 識別子を検証 251 safe_table = self._validate_identifier(table, "table") 252 columns = list(data.keys()) 253 safe_columns = self._validate_identifiers(columns, "column") 254 safe_conflict_columns = self._validate_identifiers(conflict_columns, "conflict column") 255 256 placeholders = [f"${i+1}" for i in range(len(safe_columns))] 257 values = [self._serialize_value(v) for v in data.values()] 258 259 # 更新対象カラムを決定 260 if update_columns is None: 261 update_columns = [c for c in columns if c not in conflict_columns] 262 safe_update_columns = self._validate_identifiers(update_columns, "update column") 263 264 # UPDATE SET句を構築 265 update_set = ", ".join( 266 [f"{col} = EXCLUDED.{col}" for col in safe_update_columns] 267 ) 268 269 query = f""" 270 INSERT INTO {safe_table} ({', '.join(safe_columns)}) 271 VALUES ({', '.join(placeholders)}) 272 ON CONFLICT ({', '.join(safe_conflict_columns)}) 273 DO UPDATE SET {update_set} 274 """ 275 276 if returning: 277 safe_returning = self._validate_identifiers(returning, "returning column") 278 query += f" RETURNING {', '.join(safe_returning)}" 279 280 return await self.db.execute_query(query, tuple(values))
挿入または更新(UPSERT)
Args: table: テーブル名 data: 挿入/更新するデータ {column: value} conflict_columns: 競合判定に使うカラム update_columns: 競合時に更新するカラム(Noneの場合は全カラム) returning: 返却するカラム名のリスト
Returns: Dict with keys: success, data, error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
282 async def select( 283 self, 284 table: str, 285 columns: Optional[List[str]] = None, 286 where: Optional[Dict[str, Any]] = None, 287 order_by: Optional[str] = None, 288 limit: Optional[int] = None, 289 ) -> Dict[str, Any]: 290 """ 291 条件付き検索 292 293 Args: 294 table: テーブル名 295 columns: 取得するカラム(Noneの場合は全カラム) 296 where: 検索条件 {column: value} 297 order_by: ソート指定(例: "created_at DESC") 298 limit: 取得件数上限 299 300 Returns: 301 Dict with keys: success, data (List[Dict]), error (optional) 302 303 Raises: 304 SQLInjectionError: 不正なテーブル名・カラム名の場合 305 """ 306 # SQL Injection対策: 識別子を検証 307 safe_table = self._validate_identifier(table, "table") 308 309 if columns is None: 310 select_cols = "*" 311 else: 312 safe_columns = self._validate_identifiers(columns, "column") 313 select_cols = ", ".join(safe_columns) 314 315 query = f"SELECT {select_cols} FROM {safe_table}" 316 params: List[Any] = [] 317 318 if where: 319 conditions = [] 320 param_index = 1 321 for col, val in where.items(): 322 safe_col = self._validate_identifier(col, "where column") 323 if val is None: 324 conditions.append(f"{safe_col} IS NULL") 325 else: 326 conditions.append(f"{safe_col} = ${param_index}") 327 params.append(self._serialize_value(val)) 328 param_index += 1 329 query += f" WHERE {' AND '.join(conditions)}" 330 331 if order_by: 332 safe_order_by = self._validate_order_by(order_by) 333 query += f" ORDER BY {safe_order_by}" 334 335 if limit is not None: 336 if not isinstance(limit, int) or limit < 0: 337 raise ValueError(f"Invalid limit value: {limit}") 338 query += f" LIMIT {limit}" 339 340 return await self.db.execute_query(query, tuple(params))
条件付き検索
Args: table: テーブル名 columns: 取得するカラム(Noneの場合は全カラム) where: 検索条件 {column: value} order_by: ソート指定(例: "created_at DESC") limit: 取得件数上限
Returns: Dict with keys: success, data (List[Dict]), error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
342 async def select_one( 343 self, 344 table: str, 345 columns: Optional[List[str]] = None, 346 where: Optional[Dict[str, Any]] = None, 347 ) -> Optional[Dict[str, Any]]: 348 """ 349 単一レコードを取得 350 351 Args: 352 table: テーブル名 353 columns: 取得するカラム 354 where: 検索条件 355 356 Returns: 357 Dict (結果がある場合) or None 358 """ 359 result = await self.select(table, columns, where, limit=1) 360 if result.get("success") and result.get("data"): 361 return result["data"][0] 362 return None
単一レコードを取得
Args: table: テーブル名 columns: 取得するカラム where: 検索条件
Returns: Dict (結果がある場合) or None
364 async def update( 365 self, 366 table: str, 367 data: Dict[str, Any], 368 where: Dict[str, Any], 369 returning: Optional[List[str]] = None, 370 ) -> Dict[str, Any]: 371 """ 372 条件付き更新 373 374 Args: 375 table: テーブル名 376 data: 更新するデータ {column: value} 377 where: 更新条件 {column: value} 378 returning: 返却するカラム名のリスト 379 380 Returns: 381 Dict with keys: success, data, error (optional) 382 383 Raises: 384 SQLInjectionError: 不正なテーブル名・カラム名の場合 385 """ 386 # SQL Injection対策: 識別子を検証 387 safe_table = self._validate_identifier(table, "table") 388 389 # SET句を構築 390 set_columns = list(data.keys()) 391 safe_set_columns = self._validate_identifiers(set_columns, "set column") 392 set_clause = ", ".join([f"{col} = ${i+1}" for i, col in enumerate(safe_set_columns)]) 393 params: List[Any] = [self._serialize_value(v) for v in data.values()] 394 395 # WHERE句を構築 396 where_columns = list(where.keys()) 397 safe_where_columns = self._validate_identifiers(where_columns, "where column") 398 where_clause = " AND ".join( 399 [f"{col} = ${len(safe_set_columns) + i + 1}" for i, col in enumerate(safe_where_columns)] 400 ) 401 params.extend([self._serialize_value(v) for v in where.values()]) 402 403 query = f"UPDATE {safe_table} SET {set_clause} WHERE {where_clause}" 404 405 if returning: 406 safe_returning = self._validate_identifiers(returning, "returning column") 407 query += f" RETURNING {', '.join(safe_returning)}" 408 409 return await self.db.execute_query(query, tuple(params))
条件付き更新
Args: table: テーブル名 data: 更新するデータ {column: value} where: 更新条件 {column: value} returning: 返却するカラム名のリスト
Returns: Dict with keys: success, data, error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
411 async def delete( 412 self, 413 table: str, 414 where: Dict[str, Any], 415 returning: Optional[List[str]] = None, 416 ) -> Dict[str, Any]: 417 """ 418 条件付き削除 419 420 Args: 421 table: テーブル名 422 where: 削除条件 {column: value} 423 returning: 返却するカラム名のリスト 424 425 Returns: 426 Dict with keys: success, data, error (optional) 427 428 Raises: 429 SQLInjectionError: 不正なテーブル名・カラム名の場合 430 """ 431 # SQL Injection対策: 識別子を検証 432 safe_table = self._validate_identifier(table, "table") 433 434 # WHERE句を構築 435 conditions = [] 436 params: List[Any] = [] 437 param_index = 1 438 for col, val in where.items(): 439 safe_col = self._validate_identifier(col, "where column") 440 if val is None: 441 conditions.append(f"{safe_col} IS NULL") 442 else: 443 conditions.append(f"{safe_col} = ${param_index}") 444 params.append(self._serialize_value(val)) 445 param_index += 1 446 447 query = f"DELETE FROM {safe_table} WHERE {' AND '.join(conditions)}" 448 449 if returning: 450 safe_returning = self._validate_identifiers(returning, "returning column") 451 query += f" RETURNING {', '.join(safe_returning)}" 452 453 return await self.db.execute_query(query, tuple(params))
条件付き削除
Args: table: テーブル名 where: 削除条件 {column: value} returning: 返却するカラム名のリスト
Returns: Dict with keys: success, data, error (optional)
Raises: SQLInjectionError: 不正なテーブル名・カラム名の場合
23class ConfigAccess: 24 """ 25 設定データアクセスクラス 26 27 各種設定テーブルからの読み込み機能を提供: 28 - エージェント指示・設定 29 - ツール設定 30 - ガードレール設定 31 - システム設定 32 - MCP設定 33 """ 34 35 def __init__(self, db: DataAccess): 36 """ 37 Args: 38 db: DataAccessインスタンス 39 """ 40 self._db = db 41 self.logger = logger 42 43 def _is_cli_mode(self, mode: Optional[str] = None) -> bool: 44 """CLIモードかどうかを判定""" 45 effective_mode = mode or os.environ.get("EXECUTION_MODE", "pod") 46 return effective_mode in CLI_MODES or effective_mode.startswith("cli") 47 48 # ===================================================== 49 # AGENT INSTRUCTIONS 50 # ===================================================== 51 52 async def get_agent_instructions(self, agent_type: str) -> Dict[str, Any]: 53 """ 54 エージェント指示を取得 55 56 Args: 57 agent_type: エージェントタイプ ('intent', 'react', 'coordinator') 58 59 Returns: 60 Dict with keys: success, data (instructions string), error (optional) 61 """ 62 self.logger.info(f"[IN] get_agent_instructions: agent_type={agent_type}") 63 try: 64 result = await self._db.select( 65 table="agent_instructions", 66 columns=["instructions"], 67 where={"agent_type": agent_type}, 68 ) 69 70 # DBエラーを先にチェック 71 if not result.get("success"): 72 self.logger.error(f"[ERROR] get_agent_instructions: DB error for {agent_type}") 73 return { 74 "success": False, 75 "error": result.get("error", "Database query failed"), 76 "error_code": result.get("error_code", "DATABASE_ERROR") 77 } 78 79 # データが存在するかチェック 80 if result.get("data"): 81 instructions = result["data"][0]["instructions"] 82 self.logger.info(f"[OUT] get_agent_instructions: Found for {agent_type}") 83 return {"success": True, "data": instructions} 84 else: 85 self.logger.warning(f"[OUT] get_agent_instructions: Not found for {agent_type}") 86 return { 87 "success": False, 88 "error": f"No instructions found for agent type: {agent_type}", 89 "error_code": "INSTRUCTIONS_NOT_FOUND" 90 } 91 92 except Exception as e: 93 self.logger.error(f"[ERROR] get_agent_instructions: {str(e)}") 94 return { 95 "success": False, 96 "error": f"Database error: {str(e)}", 97 "error_code": "DATABASE_ERROR" 98 } 99 100 # ===================================================== 101 # AGENT CONFIGURATIONS 102 # ===================================================== 103 104 async def get_all_agent_configs(self) -> Dict[str, Any]: 105 """ 106 全エージェント設定を一括取得 107 108 Returns: 109 Dict with keys: success, data ({agent_type: {config_level: config}}), error (optional) 110 """ 111 self.logger.info("[IN] get_all_agent_configs") 112 try: 113 query = """ 114 SELECT agent_type, config, config_level 115 FROM agent_configurations 116 WHERE is_active = TRUE 117 ORDER BY agent_type, config_level 118 """ 119 120 result = await self._db.execute_query(query) 121 122 # DBエラーを先にチェック 123 if not result.get("success"): 124 self.logger.error("[ERROR] get_all_agent_configs: DB error") 125 return { 126 "success": False, 127 "error": result.get("error", "Database query failed"), 128 "error_code": result.get("error_code", "DATABASE_ERROR") 129 } 130 131 configs = {} 132 if result.get("data"): 133 for row in result["data"]: 134 agent_type = row['agent_type'] 135 config_level = row['config_level'] 136 137 if agent_type not in configs: 138 configs[agent_type] = {} 139 140 configs[agent_type][config_level] = row['config'] 141 142 self.logger.info(f"[OUT] get_all_agent_configs: Found {len(configs)} types") 143 else: 144 self.logger.info("[OUT] get_all_agent_configs: No configs found (empty result)") 145 146 return {"success": True, "data": configs} 147 148 except Exception as e: 149 self.logger.error(f"[ERROR] get_all_agent_configs: {str(e)}") 150 return { 151 "success": False, 152 "error": f"Failed to get agent configs: {str(e)}", 153 "error_code": "CONFIG_FETCH_ERROR" 154 } 155 156 async def get_agent_config( 157 self, 158 agent_type: str, 159 requested_level: str = "default" 160 ) -> Dict[str, Any]: 161 """ 162 エージェント設定を取得(優先順位: user > requested_level > default) 163 164 Args: 165 agent_type: エージェントタイプ 166 requested_level: リクエストレベル ('default', 'high_performance') 167 168 Returns: 169 Dict with keys: success, data ({config, level}), error (optional) 170 """ 171 self.logger.info(f"[IN] get_agent_config: agent_type={agent_type}, level={requested_level}") 172 try: 173 # まずユーザー設定を確認 174 user_config_query = """ 175 SELECT config FROM agent_configurations 176 WHERE config_level = 'user' AND agent_type = $1 AND is_active = TRUE 177 """ 178 result = await self._db.execute_query(user_config_query, (agent_type,)) 179 180 # DBエラーを先にチェック 181 if not result.get("success"): 182 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 183 return { 184 "success": False, 185 "error": result.get("error", "Database query failed"), 186 "error_code": result.get("error_code", "DATABASE_ERROR") 187 } 188 189 if result.get("data"): 190 self.logger.info(f"[OUT] get_agent_config: Found user config for {agent_type}") 191 return {"success": True, "data": {"config": result["data"][0]['config'], "level": "user"}} 192 193 # 次に要求されたレベルの設定を確認 194 level_config_query = """ 195 SELECT config FROM agent_configurations 196 WHERE config_level = $1 AND agent_type = $2 AND is_active = TRUE 197 """ 198 result = await self._db.execute_query(level_config_query, (requested_level, agent_type)) 199 200 if not result.get("success"): 201 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 202 return { 203 "success": False, 204 "error": result.get("error", "Database query failed"), 205 "error_code": result.get("error_code", "DATABASE_ERROR") 206 } 207 208 if result.get("data"): 209 self.logger.info(f"[OUT] get_agent_config: Found {requested_level} config for {agent_type}") 210 return {"success": True, "data": {"config": result["data"][0]['config'], "level": requested_level}} 211 212 # 最後にデフォルト設定を確認 213 if requested_level != "default": 214 result = await self._db.execute_query(level_config_query, ("default", agent_type)) 215 216 if not result.get("success"): 217 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 218 return { 219 "success": False, 220 "error": result.get("error", "Database query failed"), 221 "error_code": result.get("error_code", "DATABASE_ERROR") 222 } 223 224 if result.get("data"): 225 self.logger.info(f"[OUT] get_agent_config: Found default config for {agent_type}") 226 return {"success": True, "data": {"config": result["data"][0]['config'], "level": "default"}} 227 228 self.logger.warning(f"[OUT] get_agent_config: No config found for {agent_type}") 229 return { 230 "success": False, 231 "error": f"No configuration found for agent: {agent_type}", 232 "error_code": "CONFIG_NOT_FOUND" 233 } 234 235 except Exception as e: 236 self.logger.error(f"[ERROR] get_agent_config: {str(e)}") 237 return { 238 "success": False, 239 "error": str(e), 240 "error_code": "DATABASE_ERROR" 241 } 242 243 # ===================================================== 244 # TOOL CONFIGURATIONS 245 # ===================================================== 246 247 async def get_tool_config(self, tool_name: str) -> Dict[str, Any]: 248 """ 249 ツール設定を取得 250 251 Args: 252 tool_name: ツール名 253 254 Returns: 255 Dict with keys: success, data ({config, tool_name}), error (optional) 256 """ 257 self.logger.info(f"[IN] get_tool_config: tool_name={tool_name}") 258 try: 259 result = await self._db.select( 260 table="tool_configurations", 261 columns=["config"], 262 where={"tool_name": tool_name, "is_active": True}, 263 ) 264 265 # DBエラーを先にチェック 266 if not result.get("success"): 267 self.logger.error(f"[ERROR] get_tool_config: DB error for {tool_name}") 268 return { 269 "success": False, 270 "error": result.get("error", "Database query failed"), 271 "error_code": result.get("error_code", "DATABASE_ERROR") 272 } 273 274 if result.get("data"): 275 self.logger.info(f"[OUT] get_tool_config: Found for {tool_name}") 276 return {"success": True, "data": {"config": result["data"][0]['config'], "tool_name": tool_name}} 277 278 self.logger.warning(f"[OUT] get_tool_config: Not found for {tool_name}") 279 return { 280 "success": False, 281 "error": f"No configuration found for tool: {tool_name}", 282 "error_code": "CONFIG_NOT_FOUND" 283 } 284 285 except Exception as e: 286 self.logger.error(f"[ERROR] get_tool_config: {str(e)}") 287 return { 288 "success": False, 289 "error": str(e), 290 "error_code": "DATABASE_ERROR" 291 } 292 293 async def get_all_tool_configs(self) -> Dict[str, Any]: 294 """ 295 全ツール設定を一括取得 296 297 Returns: 298 Dict with keys: success, data ({tool_name: config}), error (optional) 299 """ 300 self.logger.info("[IN] get_all_tool_configs") 301 try: 302 query = """ 303 SELECT tool_name, config 304 FROM tool_configurations 305 WHERE is_active = TRUE 306 ORDER BY tool_name 307 """ 308 309 result = await self._db.execute_query(query) 310 311 # DBエラーを先にチェック 312 if not result.get("success"): 313 self.logger.error("[ERROR] get_all_tool_configs: DB error") 314 return { 315 "success": False, 316 "error": result.get("error", "Database query failed"), 317 "error_code": result.get("error_code", "DATABASE_ERROR") 318 } 319 320 configs = {} 321 if result.get("data"): 322 for row in result["data"]: 323 configs[row['tool_name']] = row['config'] 324 self.logger.info(f"[OUT] get_all_tool_configs: Found {len(configs)} tools") 325 else: 326 self.logger.info("[OUT] get_all_tool_configs: No configs found (empty result)") 327 328 return {"success": True, "data": configs} 329 330 except Exception as e: 331 self.logger.error(f"[ERROR] get_all_tool_configs: {str(e)}") 332 return { 333 "success": False, 334 "error": f"Failed to get tool configs: {str(e)}", 335 "error_code": "CONFIG_FETCH_ERROR" 336 } 337 338 # ===================================================== 339 # GUARDRAILS SETTINGS 340 # ===================================================== 341 342 async def get_guardrails_settings(self) -> Dict[str, Any]: 343 """ 344 Guardrails設定を取得 345 346 Returns: 347 Dict with keys: success, data (settings dict) 348 """ 349 self.logger.info("[IN] get_guardrails_settings") 350 351 default_settings = { 352 "prompt_shield_enabled": True, 353 "moderation_threshold": 2, 354 "pii_masking_enabled": True 355 } 356 357 try: 358 result = await self._db.select( 359 table="guardrails_settings", 360 columns=["prompt_shield_enabled", "moderation_threshold", "pii_masking_enabled"], 361 where={"id": 1}, 362 ) 363 364 # DBエラーを先にチェック 365 if not result.get("success"): 366 self.logger.error("[ERROR] get_guardrails_settings: DB error, using defaults") 367 # Guardrailsはデフォルト値を返す(フォールバック) 368 return {"success": True, "data": default_settings} 369 370 if result.get("data"): 371 settings = result["data"][0] 372 self.logger.info("[OUT] get_guardrails_settings: Found") 373 return {"success": True, "data": settings} 374 else: 375 self.logger.info("[OUT] get_guardrails_settings: No settings found, using defaults") 376 return {"success": True, "data": default_settings} 377 378 except Exception as e: 379 self.logger.error(f"[ERROR] get_guardrails_settings: {str(e)}, using defaults") 380 return {"success": True, "data": default_settings} 381 382 # ===================================================== 383 # SYSTEM SETTINGS 384 # ===================================================== 385 386 async def get_system_settings(self) -> Dict[str, Any]: 387 """ 388 システム設定を取得(アクティブな全レコード) 389 390 Returns: 391 Dict with keys: success, data (list of settings) 392 """ 393 self.logger.info("[IN] get_system_settings") 394 try: 395 query = """ 396 SELECT user_prompt, guardrail_text, is_active, priority 397 FROM system_settings 398 WHERE organization_id IS NULL AND is_active = true 399 """ 400 result = await self._db.execute_query(query) 401 402 # DBエラーを先にチェック 403 if not result.get("success"): 404 self.logger.error("[ERROR] get_system_settings: DB error") 405 return { 406 "success": False, 407 "error": result.get("error", "Database query failed"), 408 "error_code": result.get("error_code", "DATABASE_ERROR") 409 } 410 411 if result.get("data"): 412 self.logger.info(f"[OUT] get_system_settings: Found {len(result['data'])} settings") 413 return {"success": True, "data": result["data"]} 414 else: 415 self.logger.info("[OUT] get_system_settings: No settings found (empty result)") 416 return {"success": True, "data": []} 417 418 except Exception as e: 419 self.logger.error(f"[ERROR] get_system_settings: {str(e)}") 420 return { 421 "success": False, 422 "error": f"Database error: {str(e)}", 423 "error_code": "DATABASE_ERROR" 424 } 425 426 # ===================================================== 427 # MCP CONFIGURATIONS 428 # ===================================================== 429 430 async def get_mcp_configurations(self, execution_mode: Optional[str] = None) -> Dict[str, Any]: 431 """ 432 アクティブなMCPサーバー設定を取得(OAuth情報含む) 433 434 Args: 435 execution_mode: 実行モード ('cli', 'cli-api', 'pod')。Noneの場合は環境変数から取得 436 437 Returns: 438 Dict with keys: success, data (list of MCP configs) 439 各MCP configには以下を含む: 440 - name, server_url, cache_tools_list, headers, only_pod 441 - use_oauth: OAuthを使用するかどうか 442 - oauth_service: 紐づいたOAuthプロバイダーのservice名(github, slack等) 443 """ 444 self.logger.info("[IN] get_mcp_configurations") 445 try: 446 # CLI系モードの場合、only_pod=falseまたはNULLのレコードのみ取得 447 # oauth_providersとLEFT JOINしてOAuth情報も取得 448 if self._is_cli_mode(execution_mode): 449 query = """ 450 SELECT m.name, m.server_url, m.cache_tools_list, m.headers, m.only_pod, 451 m.use_oauth, o.service as oauth_service, m.transport_type 452 FROM mcp_server_configurations m 453 LEFT JOIN oauth_providers o ON o.mcp_server_id = m.id AND o.is_enabled = TRUE 454 WHERE m.is_active = TRUE 455 AND (m.only_pod = FALSE OR m.only_pod IS NULL) 456 ORDER BY m.name 457 """ 458 self.logger.info("CLI mode: Filtering out pod-only MCP servers") 459 else: 460 query = """ 461 SELECT m.name, m.server_url, m.cache_tools_list, m.headers, m.only_pod, 462 m.use_oauth, o.service as oauth_service, m.transport_type 463 FROM mcp_server_configurations m 464 LEFT JOIN oauth_providers o ON o.mcp_server_id = m.id AND o.is_enabled = TRUE 465 WHERE m.is_active = TRUE 466 ORDER BY m.name 467 """ 468 469 result = await self._db.execute_query(query) 470 471 # DBエラーを先にチェック 472 if not result.get("success"): 473 self.logger.error("[ERROR] get_mcp_configurations: DB error") 474 return { 475 "success": False, 476 "error": result.get("error", "Database query failed"), 477 "error_code": result.get("error_code", "DATABASE_ERROR") 478 } 479 480 if result.get("data"): 481 configurations = [] 482 for row in result["data"]: 483 # headersをdictに変換 484 headers_raw = row['headers'] 485 headers = None 486 if headers_raw: 487 if isinstance(headers_raw, str): 488 try: 489 headers = json.loads(headers_raw) 490 except Exception as e: 491 self.logger.warning(f"Failed to parse headers for '{row['name']}': {e}") 492 headers = None 493 elif isinstance(headers_raw, dict): 494 headers = headers_raw 495 496 if headers == {}: 497 headers = None 498 499 config = { 500 "name": row['name'], 501 "server_url": row['server_url'], 502 "cache_tools_list": row['cache_tools_list'], 503 "headers": headers, 504 "only_pod": row.get('only_pod', False), 505 # OAuth情報 506 "use_oauth": row.get('use_oauth', False), 507 "oauth_service": row.get('oauth_service'), # github, slack, google, office365等 508 # Transport type: 'sse', 'streamable_http', 'stdio' 509 "transport_type": row.get('transport_type', 'sse'), 510 } 511 configurations.append(config) 512 513 self.logger.info(f"[OUT] get_mcp_configurations: Found {len(configurations)} MCP servers") 514 return {"success": True, "data": configurations} 515 else: 516 self.logger.info("[OUT] get_mcp_configurations: No MCP configs found (empty result)") 517 return {"success": True, "data": []} 518 519 except Exception as e: 520 self.logger.error(f"[ERROR] get_mcp_configurations: {str(e)}") 521 return { 522 "success": False, 523 "error": f"Database error: {str(e)}", 524 "error_code": "DATABASE_ERROR" 525 }
設定データアクセスクラス
各種設定テーブルからの読み込み機能を提供:
- エージェント指示・設定
- ツール設定
- ガードレール設定
- システム設定
- MCP設定
35 def __init__(self, db: DataAccess): 36 """ 37 Args: 38 db: DataAccessインスタンス 39 """ 40 self._db = db 41 self.logger = logger
Args: db: DataAccessインスタンス
52 async def get_agent_instructions(self, agent_type: str) -> Dict[str, Any]: 53 """ 54 エージェント指示を取得 55 56 Args: 57 agent_type: エージェントタイプ ('intent', 'react', 'coordinator') 58 59 Returns: 60 Dict with keys: success, data (instructions string), error (optional) 61 """ 62 self.logger.info(f"[IN] get_agent_instructions: agent_type={agent_type}") 63 try: 64 result = await self._db.select( 65 table="agent_instructions", 66 columns=["instructions"], 67 where={"agent_type": agent_type}, 68 ) 69 70 # DBエラーを先にチェック 71 if not result.get("success"): 72 self.logger.error(f"[ERROR] get_agent_instructions: DB error for {agent_type}") 73 return { 74 "success": False, 75 "error": result.get("error", "Database query failed"), 76 "error_code": result.get("error_code", "DATABASE_ERROR") 77 } 78 79 # データが存在するかチェック 80 if result.get("data"): 81 instructions = result["data"][0]["instructions"] 82 self.logger.info(f"[OUT] get_agent_instructions: Found for {agent_type}") 83 return {"success": True, "data": instructions} 84 else: 85 self.logger.warning(f"[OUT] get_agent_instructions: Not found for {agent_type}") 86 return { 87 "success": False, 88 "error": f"No instructions found for agent type: {agent_type}", 89 "error_code": "INSTRUCTIONS_NOT_FOUND" 90 } 91 92 except Exception as e: 93 self.logger.error(f"[ERROR] get_agent_instructions: {str(e)}") 94 return { 95 "success": False, 96 "error": f"Database error: {str(e)}", 97 "error_code": "DATABASE_ERROR" 98 }
エージェント指示を取得
Args: agent_type: エージェントタイプ ('intent', 'react', 'coordinator')
Returns: Dict with keys: success, data (instructions string), error (optional)
104 async def get_all_agent_configs(self) -> Dict[str, Any]: 105 """ 106 全エージェント設定を一括取得 107 108 Returns: 109 Dict with keys: success, data ({agent_type: {config_level: config}}), error (optional) 110 """ 111 self.logger.info("[IN] get_all_agent_configs") 112 try: 113 query = """ 114 SELECT agent_type, config, config_level 115 FROM agent_configurations 116 WHERE is_active = TRUE 117 ORDER BY agent_type, config_level 118 """ 119 120 result = await self._db.execute_query(query) 121 122 # DBエラーを先にチェック 123 if not result.get("success"): 124 self.logger.error("[ERROR] get_all_agent_configs: DB error") 125 return { 126 "success": False, 127 "error": result.get("error", "Database query failed"), 128 "error_code": result.get("error_code", "DATABASE_ERROR") 129 } 130 131 configs = {} 132 if result.get("data"): 133 for row in result["data"]: 134 agent_type = row['agent_type'] 135 config_level = row['config_level'] 136 137 if agent_type not in configs: 138 configs[agent_type] = {} 139 140 configs[agent_type][config_level] = row['config'] 141 142 self.logger.info(f"[OUT] get_all_agent_configs: Found {len(configs)} types") 143 else: 144 self.logger.info("[OUT] get_all_agent_configs: No configs found (empty result)") 145 146 return {"success": True, "data": configs} 147 148 except Exception as e: 149 self.logger.error(f"[ERROR] get_all_agent_configs: {str(e)}") 150 return { 151 "success": False, 152 "error": f"Failed to get agent configs: {str(e)}", 153 "error_code": "CONFIG_FETCH_ERROR" 154 }
全エージェント設定を一括取得
Returns: Dict with keys: success, data ({agent_type: {config_level: config}}), error (optional)
156 async def get_agent_config( 157 self, 158 agent_type: str, 159 requested_level: str = "default" 160 ) -> Dict[str, Any]: 161 """ 162 エージェント設定を取得(優先順位: user > requested_level > default) 163 164 Args: 165 agent_type: エージェントタイプ 166 requested_level: リクエストレベル ('default', 'high_performance') 167 168 Returns: 169 Dict with keys: success, data ({config, level}), error (optional) 170 """ 171 self.logger.info(f"[IN] get_agent_config: agent_type={agent_type}, level={requested_level}") 172 try: 173 # まずユーザー設定を確認 174 user_config_query = """ 175 SELECT config FROM agent_configurations 176 WHERE config_level = 'user' AND agent_type = $1 AND is_active = TRUE 177 """ 178 result = await self._db.execute_query(user_config_query, (agent_type,)) 179 180 # DBエラーを先にチェック 181 if not result.get("success"): 182 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 183 return { 184 "success": False, 185 "error": result.get("error", "Database query failed"), 186 "error_code": result.get("error_code", "DATABASE_ERROR") 187 } 188 189 if result.get("data"): 190 self.logger.info(f"[OUT] get_agent_config: Found user config for {agent_type}") 191 return {"success": True, "data": {"config": result["data"][0]['config'], "level": "user"}} 192 193 # 次に要求されたレベルの設定を確認 194 level_config_query = """ 195 SELECT config FROM agent_configurations 196 WHERE config_level = $1 AND agent_type = $2 AND is_active = TRUE 197 """ 198 result = await self._db.execute_query(level_config_query, (requested_level, agent_type)) 199 200 if not result.get("success"): 201 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 202 return { 203 "success": False, 204 "error": result.get("error", "Database query failed"), 205 "error_code": result.get("error_code", "DATABASE_ERROR") 206 } 207 208 if result.get("data"): 209 self.logger.info(f"[OUT] get_agent_config: Found {requested_level} config for {agent_type}") 210 return {"success": True, "data": {"config": result["data"][0]['config'], "level": requested_level}} 211 212 # 最後にデフォルト設定を確認 213 if requested_level != "default": 214 result = await self._db.execute_query(level_config_query, ("default", agent_type)) 215 216 if not result.get("success"): 217 self.logger.error(f"[ERROR] get_agent_config: DB error for {agent_type}") 218 return { 219 "success": False, 220 "error": result.get("error", "Database query failed"), 221 "error_code": result.get("error_code", "DATABASE_ERROR") 222 } 223 224 if result.get("data"): 225 self.logger.info(f"[OUT] get_agent_config: Found default config for {agent_type}") 226 return {"success": True, "data": {"config": result["data"][0]['config'], "level": "default"}} 227 228 self.logger.warning(f"[OUT] get_agent_config: No config found for {agent_type}") 229 return { 230 "success": False, 231 "error": f"No configuration found for agent: {agent_type}", 232 "error_code": "CONFIG_NOT_FOUND" 233 } 234 235 except Exception as e: 236 self.logger.error(f"[ERROR] get_agent_config: {str(e)}") 237 return { 238 "success": False, 239 "error": str(e), 240 "error_code": "DATABASE_ERROR" 241 }
エージェント設定を取得(優先順位: user > requested_level > default)
Args: agent_type: エージェントタイプ requested_level: リクエストレベル ('default', 'high_performance')
Returns: Dict with keys: success, data ({config, level}), error (optional)
247 async def get_tool_config(self, tool_name: str) -> Dict[str, Any]: 248 """ 249 ツール設定を取得 250 251 Args: 252 tool_name: ツール名 253 254 Returns: 255 Dict with keys: success, data ({config, tool_name}), error (optional) 256 """ 257 self.logger.info(f"[IN] get_tool_config: tool_name={tool_name}") 258 try: 259 result = await self._db.select( 260 table="tool_configurations", 261 columns=["config"], 262 where={"tool_name": tool_name, "is_active": True}, 263 ) 264 265 # DBエラーを先にチェック 266 if not result.get("success"): 267 self.logger.error(f"[ERROR] get_tool_config: DB error for {tool_name}") 268 return { 269 "success": False, 270 "error": result.get("error", "Database query failed"), 271 "error_code": result.get("error_code", "DATABASE_ERROR") 272 } 273 274 if result.get("data"): 275 self.logger.info(f"[OUT] get_tool_config: Found for {tool_name}") 276 return {"success": True, "data": {"config": result["data"][0]['config'], "tool_name": tool_name}} 277 278 self.logger.warning(f"[OUT] get_tool_config: Not found for {tool_name}") 279 return { 280 "success": False, 281 "error": f"No configuration found for tool: {tool_name}", 282 "error_code": "CONFIG_NOT_FOUND" 283 } 284 285 except Exception as e: 286 self.logger.error(f"[ERROR] get_tool_config: {str(e)}") 287 return { 288 "success": False, 289 "error": str(e), 290 "error_code": "DATABASE_ERROR" 291 }
ツール設定を取得
Args: tool_name: ツール名
Returns: Dict with keys: success, data ({config, tool_name}), error (optional)
293 async def get_all_tool_configs(self) -> Dict[str, Any]: 294 """ 295 全ツール設定を一括取得 296 297 Returns: 298 Dict with keys: success, data ({tool_name: config}), error (optional) 299 """ 300 self.logger.info("[IN] get_all_tool_configs") 301 try: 302 query = """ 303 SELECT tool_name, config 304 FROM tool_configurations 305 WHERE is_active = TRUE 306 ORDER BY tool_name 307 """ 308 309 result = await self._db.execute_query(query) 310 311 # DBエラーを先にチェック 312 if not result.get("success"): 313 self.logger.error("[ERROR] get_all_tool_configs: DB error") 314 return { 315 "success": False, 316 "error": result.get("error", "Database query failed"), 317 "error_code": result.get("error_code", "DATABASE_ERROR") 318 } 319 320 configs = {} 321 if result.get("data"): 322 for row in result["data"]: 323 configs[row['tool_name']] = row['config'] 324 self.logger.info(f"[OUT] get_all_tool_configs: Found {len(configs)} tools") 325 else: 326 self.logger.info("[OUT] get_all_tool_configs: No configs found (empty result)") 327 328 return {"success": True, "data": configs} 329 330 except Exception as e: 331 self.logger.error(f"[ERROR] get_all_tool_configs: {str(e)}") 332 return { 333 "success": False, 334 "error": f"Failed to get tool configs: {str(e)}", 335 "error_code": "CONFIG_FETCH_ERROR" 336 }
全ツール設定を一括取得
Returns: Dict with keys: success, data ({tool_name: config}), error (optional)
342 async def get_guardrails_settings(self) -> Dict[str, Any]: 343 """ 344 Guardrails設定を取得 345 346 Returns: 347 Dict with keys: success, data (settings dict) 348 """ 349 self.logger.info("[IN] get_guardrails_settings") 350 351 default_settings = { 352 "prompt_shield_enabled": True, 353 "moderation_threshold": 2, 354 "pii_masking_enabled": True 355 } 356 357 try: 358 result = await self._db.select( 359 table="guardrails_settings", 360 columns=["prompt_shield_enabled", "moderation_threshold", "pii_masking_enabled"], 361 where={"id": 1}, 362 ) 363 364 # DBエラーを先にチェック 365 if not result.get("success"): 366 self.logger.error("[ERROR] get_guardrails_settings: DB error, using defaults") 367 # Guardrailsはデフォルト値を返す(フォールバック) 368 return {"success": True, "data": default_settings} 369 370 if result.get("data"): 371 settings = result["data"][0] 372 self.logger.info("[OUT] get_guardrails_settings: Found") 373 return {"success": True, "data": settings} 374 else: 375 self.logger.info("[OUT] get_guardrails_settings: No settings found, using defaults") 376 return {"success": True, "data": default_settings} 377 378 except Exception as e: 379 self.logger.error(f"[ERROR] get_guardrails_settings: {str(e)}, using defaults") 380 return {"success": True, "data": default_settings}
Guardrails設定を取得
Returns: Dict with keys: success, data (settings dict)
386 async def get_system_settings(self) -> Dict[str, Any]: 387 """ 388 システム設定を取得(アクティブな全レコード) 389 390 Returns: 391 Dict with keys: success, data (list of settings) 392 """ 393 self.logger.info("[IN] get_system_settings") 394 try: 395 query = """ 396 SELECT user_prompt, guardrail_text, is_active, priority 397 FROM system_settings 398 WHERE organization_id IS NULL AND is_active = true 399 """ 400 result = await self._db.execute_query(query) 401 402 # DBエラーを先にチェック 403 if not result.get("success"): 404 self.logger.error("[ERROR] get_system_settings: DB error") 405 return { 406 "success": False, 407 "error": result.get("error", "Database query failed"), 408 "error_code": result.get("error_code", "DATABASE_ERROR") 409 } 410 411 if result.get("data"): 412 self.logger.info(f"[OUT] get_system_settings: Found {len(result['data'])} settings") 413 return {"success": True, "data": result["data"]} 414 else: 415 self.logger.info("[OUT] get_system_settings: No settings found (empty result)") 416 return {"success": True, "data": []} 417 418 except Exception as e: 419 self.logger.error(f"[ERROR] get_system_settings: {str(e)}") 420 return { 421 "success": False, 422 "error": f"Database error: {str(e)}", 423 "error_code": "DATABASE_ERROR" 424 }
システム設定を取得(アクティブな全レコード)
Returns: Dict with keys: success, data (list of settings)
430 async def get_mcp_configurations(self, execution_mode: Optional[str] = None) -> Dict[str, Any]: 431 """ 432 アクティブなMCPサーバー設定を取得(OAuth情報含む) 433 434 Args: 435 execution_mode: 実行モード ('cli', 'cli-api', 'pod')。Noneの場合は環境変数から取得 436 437 Returns: 438 Dict with keys: success, data (list of MCP configs) 439 各MCP configには以下を含む: 440 - name, server_url, cache_tools_list, headers, only_pod 441 - use_oauth: OAuthを使用するかどうか 442 - oauth_service: 紐づいたOAuthプロバイダーのservice名(github, slack等) 443 """ 444 self.logger.info("[IN] get_mcp_configurations") 445 try: 446 # CLI系モードの場合、only_pod=falseまたはNULLのレコードのみ取得 447 # oauth_providersとLEFT JOINしてOAuth情報も取得 448 if self._is_cli_mode(execution_mode): 449 query = """ 450 SELECT m.name, m.server_url, m.cache_tools_list, m.headers, m.only_pod, 451 m.use_oauth, o.service as oauth_service, m.transport_type 452 FROM mcp_server_configurations m 453 LEFT JOIN oauth_providers o ON o.mcp_server_id = m.id AND o.is_enabled = TRUE 454 WHERE m.is_active = TRUE 455 AND (m.only_pod = FALSE OR m.only_pod IS NULL) 456 ORDER BY m.name 457 """ 458 self.logger.info("CLI mode: Filtering out pod-only MCP servers") 459 else: 460 query = """ 461 SELECT m.name, m.server_url, m.cache_tools_list, m.headers, m.only_pod, 462 m.use_oauth, o.service as oauth_service, m.transport_type 463 FROM mcp_server_configurations m 464 LEFT JOIN oauth_providers o ON o.mcp_server_id = m.id AND o.is_enabled = TRUE 465 WHERE m.is_active = TRUE 466 ORDER BY m.name 467 """ 468 469 result = await self._db.execute_query(query) 470 471 # DBエラーを先にチェック 472 if not result.get("success"): 473 self.logger.error("[ERROR] get_mcp_configurations: DB error") 474 return { 475 "success": False, 476 "error": result.get("error", "Database query failed"), 477 "error_code": result.get("error_code", "DATABASE_ERROR") 478 } 479 480 if result.get("data"): 481 configurations = [] 482 for row in result["data"]: 483 # headersをdictに変換 484 headers_raw = row['headers'] 485 headers = None 486 if headers_raw: 487 if isinstance(headers_raw, str): 488 try: 489 headers = json.loads(headers_raw) 490 except Exception as e: 491 self.logger.warning(f"Failed to parse headers for '{row['name']}': {e}") 492 headers = None 493 elif isinstance(headers_raw, dict): 494 headers = headers_raw 495 496 if headers == {}: 497 headers = None 498 499 config = { 500 "name": row['name'], 501 "server_url": row['server_url'], 502 "cache_tools_list": row['cache_tools_list'], 503 "headers": headers, 504 "only_pod": row.get('only_pod', False), 505 # OAuth情報 506 "use_oauth": row.get('use_oauth', False), 507 "oauth_service": row.get('oauth_service'), # github, slack, google, office365等 508 # Transport type: 'sse', 'streamable_http', 'stdio' 509 "transport_type": row.get('transport_type', 'sse'), 510 } 511 configurations.append(config) 512 513 self.logger.info(f"[OUT] get_mcp_configurations: Found {len(configurations)} MCP servers") 514 return {"success": True, "data": configurations} 515 else: 516 self.logger.info("[OUT] get_mcp_configurations: No MCP configs found (empty result)") 517 return {"success": True, "data": []} 518 519 except Exception as e: 520 self.logger.error(f"[ERROR] get_mcp_configurations: {str(e)}") 521 return { 522 "success": False, 523 "error": f"Database error: {str(e)}", 524 "error_code": "DATABASE_ERROR" 525 }
アクティブなMCPサーバー設定を取得(OAuth情報含む)
Args: execution_mode: 実行モード ('cli', 'cli-api', 'pod')。Noneの場合は環境変数から取得
Returns: Dict with keys: success, data (list of MCP configs) 各MCP configには以下を含む: - name, server_url, cache_tools_list, headers, only_pod - use_oauth: OAuthを使用するかどうか - oauth_service: 紐づいたOAuthプロバイダーのservice名(github, slack等)
151class EmbeddingGenerator: 152 """ 153 Azure OpenAI Embedding生成クラス 154 155 テキストをベクトル表現に変換します。 156 LRUキャッシュとレートリミット対応により、効率的かつ安定的にEmbeddingを生成します。 157 158 Features: 159 - 非同期バッチEmbedding生成(batch_generate) 160 - LRUメモリキャッシュによる重複リクエスト削減 161 - 指数バックオフ+ジッターによるレートリミット対応 162 - Azure OpenAI text-embedding-ada-002モデル対応 163 - キャッシュ統計とキャッシュクリア機能 164 165 Example: 166 >>> from agenticstar_platform import EmbeddingGenerator, EmbeddingConfig 167 >>> 168 >>> # 設定を作成 169 >>> config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding") 170 >>> 171 >>> # クライアント初期化 172 >>> generator = EmbeddingGenerator(config) 173 >>> 174 >>> # 単一テキストのEmbedding生成 175 >>> embedding = await generator.generate("Azure OpenAIの使い方を教えてください") 176 >>> print(f"Embedding dimensions: {len(embedding)}") # 1536 177 >>> 178 >>> # バッチEmbedding生成(効率的) 179 >>> texts = ["テキスト1", "テキスト2", "テキスト3"] 180 >>> embeddings = await generator.batch_generate(texts, batch_size=16) 181 >>> print(f"Generated {len(embeddings)} embeddings") 182 >>> 183 >>> # キャッシュ統計確認 184 >>> stats = generator.get_cache_stats() 185 >>> print(f"Cache utilization: {stats['cache_utilization']:.1%}") 186 """ 187 188 def __init__(self, config: EmbeddingConfig): 189 """ 190 Initialize Embedding Generator 191 192 Args: 193 config: EmbeddingConfig instance 194 """ 195 self.client = AsyncAzureOpenAI( 196 azure_endpoint=config.base_url, 197 api_key=config.api_key, 198 api_version=config.api_version, 199 ) 200 self.deployment_name = config.model 201 self.model_name = config.model # For model_version tracking 202 self.dimensions = config.dimensions 203 204 # リトライ設定 205 self._max_retries = config.max_retries 206 self._base_delay = config.base_delay 207 self._max_delay = config.max_delay 208 209 # LRU Cache 210 self._memory_cache: Dict[str, List[float]] = {} 211 self._cache_order: List[str] = [] # LRU tracking 212 self._max_cache_size = config.max_cache_size 213 214 logger.info( 215 f"EmbeddingGenerator initialized - model={config.model}, " 216 f"cache_size={config.max_cache_size}, max_retries={config.max_retries}" 217 ) 218 219 async def generate(self, text: str) -> List[float]: 220 """ 221 単一テキストのEmbeddingを生成(キャッシュ対応) 222 223 同一テキストは自動的にキャッシュからEmbeddingを取得します。 224 レートリミットに達した場合は自動リトライを行います。 225 226 Args: 227 text: Embedding対象のテキスト 228 229 Returns: 230 List[float]: Embeddingベクトル(text-embedding-ada-002の場合は1536次元) 231 232 Raises: 233 RateLimitExceededError: レートリミット超過でリトライ上限到達時 234 EmbeddingError: その他のAPIエラー 235 236 Example: 237 >>> # 単一テキストのEmbedding生成 238 >>> embedding = await generator.generate("質問に回答する方法") 239 >>> print(f"Dimensions: {len(embedding)}") # 1536 240 >>> 241 >>> # キャッシュ済みの場合は即座に返る 242 >>> embedding_cached = await generator.generate("質問に回答する方法") 243 """ 244 # Cache key 245 cache_key = self._make_cache_key(text) 246 247 # Cache check 248 if cache_key in self._memory_cache: 249 logger.debug(f"Cache hit for key={cache_key[:16]}...") 250 return self._memory_cache[cache_key] 251 252 # API call with retry 253 last_error = None 254 for attempt in range(self._max_retries): 255 try: 256 logger.debug(f"Generating embedding - text_len={len(text)} chars, attempt={attempt + 1}") 257 response = await self.client.embeddings.create( 258 input=text, model=self.deployment_name 259 ) 260 embedding = response.data[0].embedding 261 262 # Cache update 263 self._update_memory_cache(cache_key, embedding) 264 265 logger.debug(f"Embedding generated - dims={len(embedding)}") 266 return embedding 267 268 except Exception as e: 269 last_error = e 270 if "429" in str(e) or "rate" in str(e).lower(): 271 delay = self._calculate_backoff_delay(attempt) 272 logger.warning( 273 f"Rate limit hit (attempt {attempt + 1}/{self._max_retries}), " 274 f"retrying after {delay:.1f}s..." 275 ) 276 await asyncio.sleep(delay) 277 continue 278 else: 279 logger.error(f"Embedding generation failed: {e}") 280 raise EmbeddingError(f"Embedding generation failed: {e}") from e 281 282 # リトライ上限到達 283 raise RateLimitExceededError( 284 f"Rate limit exceeded after {self._max_retries} retries: {last_error}" 285 ) 286 287 async def batch_generate( 288 self, texts: List[str], batch_size: int = 16 289 ) -> List[List[float]]: 290 """ 291 複数テキストのEmbeddingを一括生成(バッチ処理) 292 293 大量のテキストを効率的にEmbedding化します。 294 キャッシュ済みのテキストはスキップし、APIコール数を削減します。 295 296 Args: 297 texts: Embedding対象のテキストリスト 298 batch_size: APIバッチサイズ(推奨: 16-32)。大きすぎるとレートリミットに達しやすくなる 299 300 Returns: 301 List[List[float]]: Embeddingベクトルのリスト(入力と同じ順序) 302 303 Raises: 304 RateLimitExceededError: レートリミット超過でリトライ上限到達時 305 EmbeddingError: その他のAPIエラー 306 307 Example: 308 >>> # バッチEmbedding生成 309 >>> texts = [ 310 ... "ドキュメント1の内容", 311 ... "ドキュメント2の内容", 312 ... "ドキュメント3の内容" 313 ... ] 314 >>> embeddings = await generator.batch_generate(texts, batch_size=16) 315 >>> print(f"Generated {len(embeddings)} embeddings") 316 >>> 317 >>> # 各Embeddingを確認 318 >>> for i, emb in enumerate(embeddings): 319 ... print(f"Text {i}: {len(emb)} dimensions") 320 """ 321 logger.info( 322 f"Batch embedding generation - texts={len(texts)}, " 323 f"batch_size={batch_size}" 324 ) 325 326 # Check cache 327 embeddings: List[Optional[List[float]]] = [] 328 uncached_texts: List[str] = [] 329 uncached_indices: List[int] = [] 330 331 for i, text in enumerate(texts): 332 cache_key = self._make_cache_key(text) 333 if cache_key in self._memory_cache: 334 embeddings.append(self._memory_cache[cache_key]) 335 else: 336 uncached_texts.append(text) 337 uncached_indices.append(i) 338 embeddings.append(None) # Placeholder 339 340 logger.debug( 341 f"Cache hits: {len(texts) - len(uncached_texts)}/{len(texts)}" 342 ) 343 344 # Batch process uncached texts 345 if uncached_texts: 346 i = 0 347 while i < len(uncached_texts): 348 batch = uncached_texts[i : i + batch_size] 349 batch_indices = uncached_indices[i : i + batch_size] 350 351 last_error = None 352 for attempt in range(self._max_retries): 353 try: 354 response = await self.client.embeddings.create( 355 input=batch, model=self.deployment_name 356 ) 357 358 # Map back to original indices 359 for j, embedding_obj in enumerate(response.data): 360 original_index = batch_indices[j] 361 embedding = embedding_obj.embedding 362 embeddings[original_index] = embedding 363 364 # Cache update 365 cache_key = self._make_cache_key(batch[j]) 366 self._update_memory_cache(cache_key, embedding) 367 368 # Success - move to next batch 369 break 370 371 except Exception as e: 372 last_error = e 373 if "429" in str(e) or "rate" in str(e).lower(): 374 delay = self._calculate_backoff_delay(attempt) 375 logger.warning( 376 f"Rate limit hit in batch (attempt {attempt + 1}/{self._max_retries}), " 377 f"retrying after {delay:.1f}s..." 378 ) 379 await asyncio.sleep(delay) 380 continue 381 else: 382 logger.error(f"Batch embedding failed: {e}") 383 raise EmbeddingError(f"Batch embedding failed: {e}") from e 384 else: 385 # リトライ上限到達 386 raise RateLimitExceededError( 387 f"Rate limit exceeded after {self._max_retries} retries: {last_error}" 388 ) 389 390 i += batch_size 391 392 return embeddings # type: ignore 393 394 def _calculate_backoff_delay(self, attempt: int) -> float: 395 """ 396 指数バックオフ + ジッターでディレイを計算 397 398 Args: 399 attempt: 現在のリトライ回数 (0-indexed) 400 401 Returns: 402 待機時間(秒) 403 """ 404 # 指数バックオフ: base_delay * 2^attempt 405 delay = self._base_delay * (2 ** attempt) 406 # 上限適用 407 delay = min(delay, self._max_delay) 408 # ジッター追加 (±25%) 409 jitter = delay * 0.25 * (2 * random.random() - 1) 410 return delay + jitter 411 412 def _make_cache_key(self, text: str) -> str: 413 """ 414 Generate cache key (model_name + text_hash) 415 416 Args: 417 text: Input text 418 419 Returns: 420 Cache key string 421 """ 422 text_hash = hashlib.sha256(text.encode()).hexdigest()[:16] 423 return f"{self.model_name}:{text_hash}" 424 425 def _update_memory_cache(self, key: str, embedding: List[float]) -> None: 426 """ 427 Update LRU cache 428 429 Args: 430 key: Cache key 431 embedding: Embedding vector 432 """ 433 if key in self._memory_cache: 434 # Update order for existing key 435 self._cache_order.remove(key) 436 elif len(self._memory_cache) >= self._max_cache_size: 437 # Evict oldest entry 438 oldest_key = self._cache_order.pop(0) 439 del self._memory_cache[oldest_key] 440 logger.debug(f"Cache evicted - key={oldest_key[:16]}...") 441 442 self._memory_cache[key] = embedding 443 self._cache_order.append(key) 444 445 def get_cache_stats(self) -> Dict[str, Any]: 446 """ 447 キャッシュ統計情報を取得 448 449 キャッシュの使用状況を確認し、必要に応じてクリアするか判断できます。 450 451 Returns: 452 Dict[str, Any]: キャッシュ統計情報 453 - cache_size: 現在のキャッシュエントリ数 454 - max_cache_size: 最大キャッシュサイズ 455 - cache_utilization: キャッシュ使用率(0.0〜1.0) 456 - model_name: 使用中のモデル名 457 458 Example: 459 >>> stats = generator.get_cache_stats() 460 >>> print(f"Cache: {stats['cache_size']}/{stats['max_cache_size']}") 461 >>> print(f"Utilization: {stats['cache_utilization']:.1%}") 462 >>> if stats['cache_utilization'] > 0.9: 463 ... generator.clear_cache() 464 """ 465 return { 466 "cache_size": len(self._memory_cache), 467 "max_cache_size": self._max_cache_size, 468 "cache_utilization": len(self._memory_cache) / self._max_cache_size if self._max_cache_size > 0 else 0, 469 "model_name": self.model_name, 470 } 471 472 def clear_cache(self) -> None: 473 """ 474 キャッシュを全クリア 475 476 メモリ使用量が増えた場合や、新しいモデルに切り替える場合に使用します。 477 478 Example: 479 >>> # キャッシュ統計確認 480 >>> stats = generator.get_cache_stats() 481 >>> print(f"Before: {stats['cache_size']} entries") 482 >>> 483 >>> # キャッシュクリア 484 >>> generator.clear_cache() 485 >>> print(f"After: {generator.get_cache_stats()['cache_size']} entries") # 0 486 """ 487 self._memory_cache.clear() 488 self._cache_order.clear() 489 logger.info("Embedding cache cleared")
Azure OpenAI Embedding生成クラス
テキストをベクトル表現に変換します。 LRUキャッシュとレートリミット対応により、効率的かつ安定的にEmbeddingを生成します。
Features: - 非同期バッチEmbedding生成(batch_generate) - LRUメモリキャッシュによる重複リクエスト削減 - 指数バックオフ+ジッターによるレートリミット対応 - Azure OpenAI text-embedding-ada-002モデル対応 - キャッシュ統計とキャッシュクリア機能
Example:
from agenticstar_platform import EmbeddingGenerator, EmbeddingConfig
設定を作成
config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding")
クライアント初期化
generator = EmbeddingGenerator(config)
単一テキストのEmbedding生成
embedding = await generator.generate("Azure OpenAIの使い方を教えてください") print(f"Embedding dimensions: {len(embedding)}") # 1536
バッチEmbedding生成(効率的)
texts = ["テキスト1", "テキスト2", "テキスト3"] embeddings = await generator.batch_generate(texts, batch_size=16) print(f"Generated {len(embeddings)} embeddings")
キャッシュ統計確認
stats = generator.get_cache_stats() print(f"Cache utilization: {stats['cache_utilization']:.1%}")
188 def __init__(self, config: EmbeddingConfig): 189 """ 190 Initialize Embedding Generator 191 192 Args: 193 config: EmbeddingConfig instance 194 """ 195 self.client = AsyncAzureOpenAI( 196 azure_endpoint=config.base_url, 197 api_key=config.api_key, 198 api_version=config.api_version, 199 ) 200 self.deployment_name = config.model 201 self.model_name = config.model # For model_version tracking 202 self.dimensions = config.dimensions 203 204 # リトライ設定 205 self._max_retries = config.max_retries 206 self._base_delay = config.base_delay 207 self._max_delay = config.max_delay 208 209 # LRU Cache 210 self._memory_cache: Dict[str, List[float]] = {} 211 self._cache_order: List[str] = [] # LRU tracking 212 self._max_cache_size = config.max_cache_size 213 214 logger.info( 215 f"EmbeddingGenerator initialized - model={config.model}, " 216 f"cache_size={config.max_cache_size}, max_retries={config.max_retries}" 217 )
Initialize Embedding Generator
Args: config: EmbeddingConfig instance
219 async def generate(self, text: str) -> List[float]: 220 """ 221 単一テキストのEmbeddingを生成(キャッシュ対応) 222 223 同一テキストは自動的にキャッシュからEmbeddingを取得します。 224 レートリミットに達した場合は自動リトライを行います。 225 226 Args: 227 text: Embedding対象のテキスト 228 229 Returns: 230 List[float]: Embeddingベクトル(text-embedding-ada-002の場合は1536次元) 231 232 Raises: 233 RateLimitExceededError: レートリミット超過でリトライ上限到達時 234 EmbeddingError: その他のAPIエラー 235 236 Example: 237 >>> # 単一テキストのEmbedding生成 238 >>> embedding = await generator.generate("質問に回答する方法") 239 >>> print(f"Dimensions: {len(embedding)}") # 1536 240 >>> 241 >>> # キャッシュ済みの場合は即座に返る 242 >>> embedding_cached = await generator.generate("質問に回答する方法") 243 """ 244 # Cache key 245 cache_key = self._make_cache_key(text) 246 247 # Cache check 248 if cache_key in self._memory_cache: 249 logger.debug(f"Cache hit for key={cache_key[:16]}...") 250 return self._memory_cache[cache_key] 251 252 # API call with retry 253 last_error = None 254 for attempt in range(self._max_retries): 255 try: 256 logger.debug(f"Generating embedding - text_len={len(text)} chars, attempt={attempt + 1}") 257 response = await self.client.embeddings.create( 258 input=text, model=self.deployment_name 259 ) 260 embedding = response.data[0].embedding 261 262 # Cache update 263 self._update_memory_cache(cache_key, embedding) 264 265 logger.debug(f"Embedding generated - dims={len(embedding)}") 266 return embedding 267 268 except Exception as e: 269 last_error = e 270 if "429" in str(e) or "rate" in str(e).lower(): 271 delay = self._calculate_backoff_delay(attempt) 272 logger.warning( 273 f"Rate limit hit (attempt {attempt + 1}/{self._max_retries}), " 274 f"retrying after {delay:.1f}s..." 275 ) 276 await asyncio.sleep(delay) 277 continue 278 else: 279 logger.error(f"Embedding generation failed: {e}") 280 raise EmbeddingError(f"Embedding generation failed: {e}") from e 281 282 # リトライ上限到達 283 raise RateLimitExceededError( 284 f"Rate limit exceeded after {self._max_retries} retries: {last_error}" 285 )
単一テキストのEmbeddingを生成(キャッシュ対応)
同一テキストは自動的にキャッシュからEmbeddingを取得します。 レートリミットに達した場合は自動リトライを行います。
Args: text: Embedding対象のテキスト
Returns: List[float]: Embeddingベクトル(text-embedding-ada-002の場合は1536次元)
Raises: RateLimitExceededError: レートリミット超過でリトライ上限到達時 EmbeddingError: その他のAPIエラー
Example:
単一テキストのEmbedding生成
embedding = await generator.generate("質問に回答する方法") print(f"Dimensions: {len(embedding)}") # 1536
キャッシュ済みの場合は即座に返る
embedding_cached = await generator.generate("質問に回答する方法")
287 async def batch_generate( 288 self, texts: List[str], batch_size: int = 16 289 ) -> List[List[float]]: 290 """ 291 複数テキストのEmbeddingを一括生成(バッチ処理) 292 293 大量のテキストを効率的にEmbedding化します。 294 キャッシュ済みのテキストはスキップし、APIコール数を削減します。 295 296 Args: 297 texts: Embedding対象のテキストリスト 298 batch_size: APIバッチサイズ(推奨: 16-32)。大きすぎるとレートリミットに達しやすくなる 299 300 Returns: 301 List[List[float]]: Embeddingベクトルのリスト(入力と同じ順序) 302 303 Raises: 304 RateLimitExceededError: レートリミット超過でリトライ上限到達時 305 EmbeddingError: その他のAPIエラー 306 307 Example: 308 >>> # バッチEmbedding生成 309 >>> texts = [ 310 ... "ドキュメント1の内容", 311 ... "ドキュメント2の内容", 312 ... "ドキュメント3の内容" 313 ... ] 314 >>> embeddings = await generator.batch_generate(texts, batch_size=16) 315 >>> print(f"Generated {len(embeddings)} embeddings") 316 >>> 317 >>> # 各Embeddingを確認 318 >>> for i, emb in enumerate(embeddings): 319 ... print(f"Text {i}: {len(emb)} dimensions") 320 """ 321 logger.info( 322 f"Batch embedding generation - texts={len(texts)}, " 323 f"batch_size={batch_size}" 324 ) 325 326 # Check cache 327 embeddings: List[Optional[List[float]]] = [] 328 uncached_texts: List[str] = [] 329 uncached_indices: List[int] = [] 330 331 for i, text in enumerate(texts): 332 cache_key = self._make_cache_key(text) 333 if cache_key in self._memory_cache: 334 embeddings.append(self._memory_cache[cache_key]) 335 else: 336 uncached_texts.append(text) 337 uncached_indices.append(i) 338 embeddings.append(None) # Placeholder 339 340 logger.debug( 341 f"Cache hits: {len(texts) - len(uncached_texts)}/{len(texts)}" 342 ) 343 344 # Batch process uncached texts 345 if uncached_texts: 346 i = 0 347 while i < len(uncached_texts): 348 batch = uncached_texts[i : i + batch_size] 349 batch_indices = uncached_indices[i : i + batch_size] 350 351 last_error = None 352 for attempt in range(self._max_retries): 353 try: 354 response = await self.client.embeddings.create( 355 input=batch, model=self.deployment_name 356 ) 357 358 # Map back to original indices 359 for j, embedding_obj in enumerate(response.data): 360 original_index = batch_indices[j] 361 embedding = embedding_obj.embedding 362 embeddings[original_index] = embedding 363 364 # Cache update 365 cache_key = self._make_cache_key(batch[j]) 366 self._update_memory_cache(cache_key, embedding) 367 368 # Success - move to next batch 369 break 370 371 except Exception as e: 372 last_error = e 373 if "429" in str(e) or "rate" in str(e).lower(): 374 delay = self._calculate_backoff_delay(attempt) 375 logger.warning( 376 f"Rate limit hit in batch (attempt {attempt + 1}/{self._max_retries}), " 377 f"retrying after {delay:.1f}s..." 378 ) 379 await asyncio.sleep(delay) 380 continue 381 else: 382 logger.error(f"Batch embedding failed: {e}") 383 raise EmbeddingError(f"Batch embedding failed: {e}") from e 384 else: 385 # リトライ上限到達 386 raise RateLimitExceededError( 387 f"Rate limit exceeded after {self._max_retries} retries: {last_error}" 388 ) 389 390 i += batch_size 391 392 return embeddings # type: ignore
複数テキストのEmbeddingを一括生成(バッチ処理)
大量のテキストを効率的にEmbedding化します。 キャッシュ済みのテキストはスキップし、APIコール数を削減します。
Args: texts: Embedding対象のテキストリスト batch_size: APIバッチサイズ(推奨: 16-32)。大きすぎるとレートリミットに達しやすくなる
Returns: List[List[float]]: Embeddingベクトルのリスト(入力と同じ順序)
Raises: RateLimitExceededError: レートリミット超過でリトライ上限到達時 EmbeddingError: その他のAPIエラー
Example:
バッチEmbedding生成
texts = [ ... "ドキュメント1の内容", ... "ドキュメント2の内容", ... "ドキュメント3の内容" ... ] embeddings = await generator.batch_generate(texts, batch_size=16) print(f"Generated {len(embeddings)} embeddings")
各Embeddingを確認
for i, emb in enumerate(embeddings): ... print(f"Text {i}: {len(emb)} dimensions")
445 def get_cache_stats(self) -> Dict[str, Any]: 446 """ 447 キャッシュ統計情報を取得 448 449 キャッシュの使用状況を確認し、必要に応じてクリアするか判断できます。 450 451 Returns: 452 Dict[str, Any]: キャッシュ統計情報 453 - cache_size: 現在のキャッシュエントリ数 454 - max_cache_size: 最大キャッシュサイズ 455 - cache_utilization: キャッシュ使用率(0.0〜1.0) 456 - model_name: 使用中のモデル名 457 458 Example: 459 >>> stats = generator.get_cache_stats() 460 >>> print(f"Cache: {stats['cache_size']}/{stats['max_cache_size']}") 461 >>> print(f"Utilization: {stats['cache_utilization']:.1%}") 462 >>> if stats['cache_utilization'] > 0.9: 463 ... generator.clear_cache() 464 """ 465 return { 466 "cache_size": len(self._memory_cache), 467 "max_cache_size": self._max_cache_size, 468 "cache_utilization": len(self._memory_cache) / self._max_cache_size if self._max_cache_size > 0 else 0, 469 "model_name": self.model_name, 470 }
キャッシュ統計情報を取得
キャッシュの使用状況を確認し、必要に応じてクリアするか判断できます。
Returns: Dict[str, Any]: キャッシュ統計情報 - cache_size: 現在のキャッシュエントリ数 - max_cache_size: 最大キャッシュサイズ - cache_utilization: キャッシュ使用率(0.0〜1.0) - model_name: 使用中のモデル名
Example:
stats = generator.get_cache_stats() print(f"Cache: {stats['cache_size']}/{stats['max_cache_size']}") print(f"Utilization: {stats['cache_utilization']:.1%}") if stats['cache_utilization'] > 0.9: ... generator.clear_cache()
472 def clear_cache(self) -> None: 473 """ 474 キャッシュを全クリア 475 476 メモリ使用量が増えた場合や、新しいモデルに切り替える場合に使用します。 477 478 Example: 479 >>> # キャッシュ統計確認 480 >>> stats = generator.get_cache_stats() 481 >>> print(f"Before: {stats['cache_size']} entries") 482 >>> 483 >>> # キャッシュクリア 484 >>> generator.clear_cache() 485 >>> print(f"After: {generator.get_cache_stats()['cache_size']} entries") # 0 486 """ 487 self._memory_cache.clear() 488 self._cache_order.clear() 489 logger.info("Embedding cache cleared")
キャッシュを全クリア
メモリ使用量が増えた場合や、新しいモデルに切り替える場合に使用します。
Example:
キャッシュ統計確認
stats = generator.get_cache_stats() print(f"Before: {stats['cache_size']} entries")
キャッシュクリア
generator.clear_cache() print(f"After: {generator.get_cache_stats()['cache_size']} entries") # 0
38@dataclass 39class EmbeddingConfig: 40 """Embedding設定 41 42 Azure OpenAI Embedding APIへの接続設定を管理します。 43 44 Attributes: 45 base_url: Azure OpenAI エンドポイントURL(例: "https://your-resource.openai.azure.com/") 46 api_key: APIキー 47 model: デプロイメント名/モデル名(例: "text-embedding-3-small") 48 api_version: APIバージョン(デフォルト: "2024-02-15-preview") 49 dimensions: Embedding次元数(text-embedding-3-smallは1536) 50 max_cache_size: キャッシュ最大サイズ(デフォルト: 10000) 51 max_retries: リトライ回数(デフォルト: 5) 52 base_delay: 基本待機時間秒(デフォルト: 1.0) 53 max_delay: 最大待機時間秒(デフォルト: 60.0) 54 55 Example: 56 >>> # 標準的な作成方法 57 >>> config = EmbeddingConfig( 58 ... base_url="https://your-resource.openai.azure.com/", 59 ... api_key="your-api-key", 60 ... model="text-embedding-3-small", 61 ... dimensions=1536, 62 ... ) 63 64 >>> # 辞書から作成 65 >>> config = EmbeddingConfig.from_dict({ 66 ... "base_url": "https://your-resource.openai.azure.com/", 67 ... "api_key": "your-api-key", 68 ... "model": "text-embedding-3-small" 69 ... }) 70 71 >>> # TOMLファイルから作成 72 >>> config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding") 73 """ 74 # 標準パラメータ名(業界標準に準拠) 75 base_url: str 76 api_key: str 77 model: str = "text-embedding-ada-002" 78 api_version: str = "2024-02-15-preview" 79 max_cache_size: int = 10000 80 dimensions: int = 1536 # text-embedding-ada-002のデフォルト 81 # リトライ設定 82 max_retries: int = 5 83 base_delay: float = 1.0 84 max_delay: float = 60.0 85 86 @classmethod 87 def from_dict(cls, data: Dict[str, Any]) -> "EmbeddingConfig": 88 """辞書からEmbeddingConfigを作成 89 90 Args: 91 data: 設定辞書(base_url, api_key, model 等) 92 93 Returns: 94 EmbeddingConfig instance 95 96 Example: 97 >>> config = EmbeddingConfig.from_dict({ 98 ... "base_url": "https://your-resource.openai.azure.com/", 99 ... "api_key": "your-api-key", 100 ... "model": "text-embedding-3-small", 101 ... "dimensions": 1536 102 ... }) 103 """ 104 base_url = data.get("base_url") 105 if not base_url: 106 raise ValueError("'base_url' is required") 107 108 model = data.get("model", "text-embedding-ada-002") 109 110 return cls( 111 base_url=base_url, 112 api_key=data["api_key"], 113 model=model, 114 api_version=data.get("api_version", "2024-02-15-preview"), 115 max_cache_size=data.get("max_cache_size", 10000), 116 dimensions=data.get("dimensions", 1536), 117 max_retries=data.get("max_retries", 5), 118 base_delay=data.get("base_delay", 1.0), 119 max_delay=data.get("max_delay", 60.0), 120 ) 121 122 @classmethod 123 def from_toml(cls, toml_path: str, section: str = "rag.embedding") -> "EmbeddingConfig": 124 """TOMLファイルからEmbeddingConfigを作成 125 126 Args: 127 toml_path: TOMLファイルパス 128 section: セクション名(ドット区切りでネスト対応) 129 130 Returns: 131 EmbeddingConfig instance 132 133 Example: 134 >>> # config.toml の [rag.embedding] セクションを読み込み 135 >>> config = EmbeddingConfig.from_toml("config.toml") 136 """ 137 if not TOML_AVAILABLE: 138 raise ImportError("toml library is required: pip install toml") 139 140 with open(toml_path, "r") as f: 141 toml_data = toml.load(f) 142 143 # ネストされたセクションをドット区切りで辿る 144 data = toml_data 145 for key in section.split("."): 146 data = data[key] 147 148 return cls.from_dict(data)
Embedding設定
Azure OpenAI Embedding APIへの接続設定を管理します。
Attributes: base_url: Azure OpenAI エンドポイントURL(例: "https://your-resource.openai.azure.com/") api_key: APIキー model: デプロイメント名/モデル名(例: "text-embedding-3-small") api_version: APIバージョン(デフォルト: "2024-02-15-preview") dimensions: Embedding次元数(text-embedding-3-smallは1536) max_cache_size: キャッシュ最大サイズ(デフォルト: 10000) max_retries: リトライ回数(デフォルト: 5) base_delay: 基本待機時間秒(デフォルト: 1.0) max_delay: 最大待機時間秒(デフォルト: 60.0)
Example:
標準的な作成方法
config = EmbeddingConfig( ... base_url="https://your-resource.openai.azure.com/", ... api_key="your-api-key", ... model="text-embedding-3-small", ... dimensions=1536, ... )
>>> # 辞書から作成 >>> config = EmbeddingConfig.from_dict({ ... "base_url": "https://your-resource.openai.azure.com/", ... "api_key": "your-api-key", ... "model": "text-embedding-3-small" ... }) >>> # TOMLファイルから作成 >>> config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding")
86 @classmethod 87 def from_dict(cls, data: Dict[str, Any]) -> "EmbeddingConfig": 88 """辞書からEmbeddingConfigを作成 89 90 Args: 91 data: 設定辞書(base_url, api_key, model 等) 92 93 Returns: 94 EmbeddingConfig instance 95 96 Example: 97 >>> config = EmbeddingConfig.from_dict({ 98 ... "base_url": "https://your-resource.openai.azure.com/", 99 ... "api_key": "your-api-key", 100 ... "model": "text-embedding-3-small", 101 ... "dimensions": 1536 102 ... }) 103 """ 104 base_url = data.get("base_url") 105 if not base_url: 106 raise ValueError("'base_url' is required") 107 108 model = data.get("model", "text-embedding-ada-002") 109 110 return cls( 111 base_url=base_url, 112 api_key=data["api_key"], 113 model=model, 114 api_version=data.get("api_version", "2024-02-15-preview"), 115 max_cache_size=data.get("max_cache_size", 10000), 116 dimensions=data.get("dimensions", 1536), 117 max_retries=data.get("max_retries", 5), 118 base_delay=data.get("base_delay", 1.0), 119 max_delay=data.get("max_delay", 60.0), 120 )
辞書からEmbeddingConfigを作成
Args: data: 設定辞書(base_url, api_key, model 等)
Returns: EmbeddingConfig instance
Example:
config = EmbeddingConfig.from_dict({ ... "base_url": "https://your-resource.openai.azure.com/", ... "api_key": "your-api-key", ... "model": "text-embedding-3-small", ... "dimensions": 1536 ... })
122 @classmethod 123 def from_toml(cls, toml_path: str, section: str = "rag.embedding") -> "EmbeddingConfig": 124 """TOMLファイルからEmbeddingConfigを作成 125 126 Args: 127 toml_path: TOMLファイルパス 128 section: セクション名(ドット区切りでネスト対応) 129 130 Returns: 131 EmbeddingConfig instance 132 133 Example: 134 >>> # config.toml の [rag.embedding] セクションを読み込み 135 >>> config = EmbeddingConfig.from_toml("config.toml") 136 """ 137 if not TOML_AVAILABLE: 138 raise ImportError("toml library is required: pip install toml") 139 140 with open(toml_path, "r") as f: 141 toml_data = toml.load(f) 142 143 # ネストされたセクションをドット区切りで辿る 144 data = toml_data 145 for key in section.split("."): 146 data = data[key] 147 148 return cls.from_dict(data)
TOMLファイルからEmbeddingConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: EmbeddingConfig instance
Example:
config.toml の [rag.embedding] セクションを読み込み
config = EmbeddingConfig.from_toml("config.toml")
172class QdrantManager: 173 """ 174 Qdrant操作クラス 175 176 ベクトルデータベースQdrantへのアクセスを提供します。 177 テキストの埋め込み生成とベクトル検索を統合的に扱えます。 178 179 Features: 180 - Collection自動作成(HNSW最適化設定) 181 - Payload Index作成(検索フィルタ用) 182 - Upsert/Batch Upsert (冪等性保証 - 同一IDは上書き) 183 - Semantic Search (フィルタ対応) 184 - 統計情報取得 185 186 Example: 187 >>> from agenticstar_platform import QdrantManager, QdrantConfig, EmbeddingGenerator, EmbeddingConfig 188 >>> 189 >>> # 設定を作成 190 >>> embedding_config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding") 191 >>> qdrant_config = QdrantConfig.from_toml("config.toml", section="rag.qdrant") 192 >>> 193 >>> # Embeddingジェネレーター作成 194 >>> embedding_gen = EmbeddingGenerator(embedding_config) 195 >>> 196 >>> # コンテキストマネージャーで使用(推奨) 197 >>> async with QdrantManager(qdrant_config, embedding_gen) as manager: 198 ... # データ追加 199 ... await manager.upsert( 200 ... id="doc-001", 201 ... content="AIエージェントの設計パターン", 202 ... metadata={"category": "technical", "author": "John"} 203 ... ) 204 ... 205 ... # 検索 206 ... results = await manager.search( 207 ... query_text="エージェント設計", 208 ... limit=5, 209 ... filter_conditions={"category": "technical"} 210 ... ) 211 ... for hit in results["data"]["results"]: 212 ... print(f"Score: {hit['score']:.3f} - {hit['payload']}") 213 """ 214 215 def __init__( 216 self, 217 config: QdrantConfig, 218 embedding_generator: EmbeddingGenerator, 219 ): 220 """ 221 Initialize Qdrant Manager 222 223 Args: 224 config: QdrantConfig instance 225 embedding_generator: EmbeddingGenerator instance 226 227 Raises: 228 QdrantConfigError: vector_sizeとembedding dimensionsの不一致 229 """ 230 # vector_size検証 231 if config.vector_size != embedding_generator.dimensions: 232 raise QdrantConfigError( 233 f"Vector size mismatch: QdrantConfig.vector_size={config.vector_size} " 234 f"!= EmbeddingConfig.dimensions={embedding_generator.dimensions}. " 235 f"These values must match for proper vector storage." 236 ) 237 238 # auth_token_provider設定時はプロキシ経由 239 # QdrantClientはurlパラメータ使用時にポート6333へデフォルト接続するため、 240 # CLIモード(プロキシ経由)ではhost/port/https/prefixで明示的に指定する必要がある 241 if config.auth_token_provider is not None: 242 # URLをパースしてhost/port/https/prefixを抽出 243 parsed = urlparse(config.url) 244 host = parsed.hostname 245 port = parsed.port or (443 if parsed.scheme == "https" else 80) 246 https = parsed.scheme == "https" 247 # パスからprefixを抽出(先頭の/を除去) 248 prefix = parsed.path.lstrip("/") if parsed.path and parsed.path != "/" else None 249 250 logger.info( 251 f"QdrantManager CLI mode - host={host}, port={port}, " 252 f"https={https}, prefix={prefix}" 253 ) 254 255 self.client = QdrantClient( 256 host=host, 257 port=port, 258 https=https, 259 prefix=prefix, 260 auth_token_provider=config.auth_token_provider, 261 prefer_grpc=False, # プロキシ経由のためgRPC無効化 262 check_compatibility=False, # バージョンチェック無効化(プロキシ経由では不要) 263 ) 264 else: 265 # 直接接続(Podモード等) 266 self.client = QdrantClient( 267 url=config.url, 268 prefer_grpc=True, 269 ) 270 271 self.config = config 272 self.collection_name = config.collection_name 273 self.embedding_generator = embedding_generator 274 self._initialized = False 275 276 logger.info( 277 f"QdrantManager initialized - url={config.url}, " 278 f"collection={config.collection_name}, vector_size={config.vector_size}" 279 ) 280 281 async def __aenter__(self) -> "QdrantManager": 282 """Context Manager: 開始時に初期化""" 283 await self.initialize() 284 return self 285 286 async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: 287 """Context Manager: 終了時にクローズ""" 288 await self.close() 289 290 def is_initialized(self) -> bool: 291 """初期化されているかどうか""" 292 return self._initialized 293 294 async def ensure_initialized(self) -> None: 295 """初期化されていなければ初期化""" 296 if not self._initialized: 297 await self.initialize() 298 299 async def initialize(self) -> None: 300 """ 301 Collection存在確認・作成・Index作成 302 """ 303 if self._initialized: 304 logger.debug("QdrantManager already initialized") 305 return 306 307 logger.info(f"Initializing QdrantManager - collection={self.collection_name}") 308 309 try: 310 # Collection existence check (sync call → asyncio.to_thread) 311 collections_response = await asyncio.to_thread( 312 self.client.get_collections 313 ) 314 collection_names = [col.name for col in collections_response.collections] 315 316 if self.collection_name not in collection_names: 317 logger.info(f"Creating collection: {self.collection_name}") 318 await self._create_collection() 319 await self._create_payload_indices() 320 else: 321 logger.info(f"Collection already exists: {self.collection_name}") 322 323 self._initialized = True 324 logger.info("QdrantManager initialized successfully") 325 326 except Exception as e: 327 logger.error(f"Failed to initialize QdrantManager: {e}") 328 raise 329 330 async def _create_collection(self) -> None: 331 """ 332 Create Qdrant collection with optimized settings 333 """ 334 # Distance mapping 335 distance_map = { 336 "cosine": Distance.COSINE, 337 "euclid": Distance.EUCLID, 338 "dot": Distance.DOT, 339 } 340 distance = distance_map.get(self.config.distance, Distance.COSINE) 341 342 await asyncio.to_thread( 343 self.client.create_collection, 344 collection_name=self.collection_name, 345 vectors_config=VectorParams( 346 size=self.config.vector_size, 347 distance=distance, 348 on_disk=self.config.on_disk_vectors, 349 ), 350 hnsw_config=HnswConfigDiff( 351 m=self.config.hnsw_m, 352 ef_construct=self.config.hnsw_ef_construct, 353 full_scan_threshold=10000, 354 on_disk=True, 355 ), 356 optimizers_config=OptimizersConfigDiff( 357 memmap_threshold=20000, 358 ), 359 on_disk_payload=self.config.on_disk_payload, 360 ) 361 logger.info(f"Collection created: {self.collection_name}") 362 363 async def _create_payload_indices(self) -> None: 364 """ 365 Create payload indices for filtering performance (configurable) 366 """ 367 for index_config in self.config.payload_indexes: 368 await asyncio.to_thread( 369 self.client.create_payload_index, 370 collection_name=self.collection_name, 371 field_name=index_config.field_name, 372 field_schema=index_config.field_schema, 373 ) 374 logger.info(f"Payload index created: {index_config.field_name}") 375 376 async def create_payload_index( 377 self, 378 field_name: str, 379 field_schema: str = "keyword" 380 ) -> Dict[str, Any]: 381 """ 382 カスタムPayload Indexを作成 383 384 Args: 385 field_name: フィールド名 386 field_schema: スキーマタイプ ("keyword", "integer", "float", "bool") 387 388 Returns: 389 Dict with success status 390 """ 391 try: 392 await asyncio.to_thread( 393 self.client.create_payload_index, 394 collection_name=self.collection_name, 395 field_name=field_name, 396 field_schema=field_schema, 397 ) 398 logger.info(f"Payload index created: {field_name}") 399 return {"success": True, "field_name": field_name} 400 except Exception as e: 401 logger.error(f"Failed to create payload index: {e}") 402 return {"success": False, "error": str(e), "error_code": "INDEX_CREATE_ERROR"} 403 404 async def upsert( 405 self, 406 id: Optional[str] = None, 407 content: Optional[str] = None, 408 metadata: Optional[Dict[str, Any]] = None, 409 ) -> Dict[str, Any]: 410 """ 411 ベクトルデータをUpsert(冪等性保証) 412 413 テキストからEmbeddingを生成し、ベクトルデータベースに保存します。 414 同一IDが既に存在する場合は上書きされます(冪等性保証)。 415 416 Args: 417 id: ドキュメントID(省略時は自動UUID生成) 418 content: 埋め込み対象のテキスト 419 metadata: メタデータ(検索時のフィルタリングや結果表示に使用) 420 421 Returns: 422 Dict[str, Any]: 実行結果 423 - success: bool 424 - data: {"id": str} (成功時) 425 - error: str, error_code: str (失敗時) 426 427 Example: 428 >>> result = await manager.upsert( 429 ... id="doc-001", # 省略可(自動UUID生成) 430 ... content="Pythonは汎用プログラミング言語です。", 431 ... metadata={"title": "Python入門", "category": "programming"} 432 ... ) 433 >>> if result["success"]: 434 ... print(f"Added: {result['data']['id']}") 435 436 >>> # ID自動生成 437 >>> result = await manager.upsert( 438 ... content="テキスト内容", 439 ... metadata={"source": "web"} 440 ... ) 441 >>> print(f"Auto-generated ID: {result['data']['id']}") 442 """ 443 await self.ensure_initialized() 444 445 import uuid as uuid_module 446 actual_id = id or str(uuid_module.uuid4()) 447 actual_metadata = metadata or {} 448 449 if not content: 450 return { 451 "success": False, 452 "error": "'content' is required", 453 "error_code": "MISSING_CONTENT", 454 } 455 456 try: 457 logger.debug(f"Upserting point - id={actual_id}") 458 459 # Generate embedding 460 embedding = await self.embedding_generator.generate(content) 461 462 # Add default metadata 463 full_payload = { 464 **actual_metadata, 465 "created_at": int(datetime.now().timestamp()), 466 "model_version": self.embedding_generator.model_name, 467 } 468 469 # Build Point 470 point = PointStruct( 471 id=actual_id, 472 vector=embedding, 473 payload=full_payload, 474 ) 475 476 # Upsert (idempotent - same ID overwrites, sync call → asyncio.to_thread) 477 await asyncio.to_thread( 478 self.client.upsert, 479 collection_name=self.collection_name, 480 points=[point], 481 ) 482 483 logger.info(f"Point upserted - id={actual_id}") 484 return {"success": True, "data": {"id": actual_id}} 485 486 except Exception as e: 487 logger.error(f"Failed to upsert point: {e}") 488 return { 489 "success": False, 490 "error": f"Upsert error: {str(e)}", 491 "error_code": "QDRANT_UPSERT_ERROR", 492 } 493 494 async def batch_upsert( 495 self, 496 items: List[Dict[str, Any]], 497 batch_size: int = 256, 498 ) -> Dict[str, Any]: 499 """ 500 バッチUpsert 501 502 Args: 503 items: List of {"id": str, "content": str, "metadata": dict} 504 batch_size: バッチサイズ 505 506 Returns: 507 Dict with success status and total_upserted 508 """ 509 await self.ensure_initialized() 510 511 try: 512 logger.info(f"Batch upserting {len(items)} points") 513 514 total_points = 0 515 516 for i in range(0, len(items), batch_size): 517 batch = items[i : i + batch_size] 518 519 # Extract texts and generate embeddings 520 texts = [item["content"] for item in batch] 521 embeddings = await self.embedding_generator.batch_generate(texts) 522 523 # Build Points 524 points = [] 525 for item, embedding in zip(batch, embeddings): 526 full_payload = { 527 **item.get("metadata", {}), 528 "created_at": int(datetime.now().timestamp()), 529 "model_version": self.embedding_generator.model_name, 530 } 531 points.append( 532 PointStruct( 533 id=item["id"], 534 vector=embedding, 535 payload=full_payload, 536 ) 537 ) 538 539 # Batch Upsert (sync call → asyncio.to_thread) 540 await asyncio.to_thread( 541 self.client.upsert, 542 collection_name=self.collection_name, 543 points=points, 544 ) 545 total_points += len(points) 546 547 logger.debug(f"Batch upserted {len(points)} points") 548 549 logger.info(f"Batch upsert completed - total={total_points}") 550 return {"success": True, "data": {"total_upserted": total_points}} 551 552 except Exception as e: 553 logger.error(f"Failed to batch upsert: {e}") 554 return { 555 "success": False, 556 "error": f"Batch upsert error: {str(e)}", 557 "error_code": "QDRANT_BATCH_UPSERT_ERROR", 558 } 559 560 async def search( 561 self, 562 query_text: str, 563 limit: int = 10, 564 score_threshold: float = 0.4, 565 filter_conditions: Optional[Dict[str, Any]] = None, 566 ef: Optional[int] = None, 567 ) -> Dict[str, Any]: 568 """ 569 セマンティック検索 570 571 クエリテキストに意味的に近いドキュメントを検索します。 572 フィルタ条件を指定して結果を絞り込むことも可能です。 573 574 Args: 575 query_text: 検索クエリ。自然言語で記述 576 limit: 結果数上限(デフォルト: 10) 577 score_threshold: スコア閾値 (Qdrant cosine: -1〜1、デフォルト: 0.4) 578 filter_conditions: フィルタ条件 {"field": value} 形式(オプション) 579 ef: HNSW検索深度 (デフォルト: limit * 2)。大きいほど精度が上がるが遅くなる 580 581 Returns: 582 Dict[str, Any]: 検索結果 583 - success: bool 584 - data: { 585 "query": str, 586 "results": [{"point_id": str, "payload": dict, "score": float, "similarity": float}], 587 "total_found": int, 588 "score_threshold": float 589 } 590 - error: str, error_code: str (失敗時) 591 592 Example: 593 >>> # 基本的な検索 594 >>> results = await manager.search( 595 ... query_text="プログラミング言語", 596 ... limit=5 597 ... ) 598 >>> for hit in results["data"]["results"]: 599 ... print(f"Score: {hit['score']:.3f} - {hit['payload']['title']}") 600 >>> 601 >>> # フィルタ付き検索 602 >>> results = await manager.search( 603 ... query_text="データ分析", 604 ... limit=10, 605 ... score_threshold=0.5, 606 ... filter_conditions={"category": "programming"} 607 ... ) 608 >>> print(f"Found {results['data']['total_found']} results") 609 """ 610 await self.ensure_initialized() 611 612 try: 613 logger.debug( 614 f"Searching - query_len={len(query_text)}, limit={limit}" 615 ) 616 617 # Generate query embedding 618 query_vector = await self.embedding_generator.generate(query_text) 619 620 # Build filter 621 query_filter = None 622 if filter_conditions: 623 must_conditions = [ 624 FieldCondition( 625 key=key, 626 match=MatchValue(value=value), 627 ) 628 for key, value in filter_conditions.items() 629 ] 630 query_filter = Filter(must=must_conditions) 631 632 # Search params 633 search_params = SearchParams( 634 hnsw_ef=ef or (limit * 2), 635 exact=False, 636 ) 637 638 # Execute search (sync call → asyncio.to_thread) 639 results = await asyncio.to_thread( 640 self.client.query_points, 641 collection_name=self.collection_name, 642 query=query_vector, 643 query_filter=query_filter, 644 limit=limit, 645 score_threshold=score_threshold, 646 search_params=search_params, 647 with_payload=True, 648 ) 649 650 # Format results 651 formatted_results = [ 652 { 653 "point_id": hit.id, 654 "payload": hit.payload, 655 "score": hit.score, 656 "similarity": self._score_to_similarity(hit.score), 657 } 658 for hit in results.points 659 ] 660 661 logger.info( 662 f"Search completed - found={len(formatted_results)}, " 663 f"query_len={len(query_text)}" 664 ) 665 return { 666 "success": True, 667 "data": { 668 "query": query_text, 669 "results": formatted_results, 670 "total_found": len(formatted_results), 671 "score_threshold": score_threshold, 672 }, 673 } 674 675 except Exception as e: 676 logger.error(f"Failed to search: {e}") 677 return { 678 "success": False, 679 "error": f"Search error: {str(e)}", 680 "error_code": "QDRANT_SEARCH_ERROR", 681 } 682 683 def _score_to_similarity(self, score: float) -> float: 684 """ 685 Qdrant cosine score (-1〜1) → 類似度 (0〜1) 変換 686 """ 687 return max(0.0, min(1.0, (score + 1.0) / 2.0)) 688 689 async def delete(self, point_ids: List[str]) -> Dict[str, Any]: 690 """ 691 ポイントを削除 692 693 指定したIDのベクトルデータを削除します。 694 695 Args: 696 point_ids: 削除するポイントIDのリスト 697 698 Returns: 699 Dict[str, Any]: 削除結果 700 - success: bool 701 - data: {"deleted_count": int} (成功時) 702 - error: str, error_code: str (失敗時) 703 704 Example: 705 >>> # 単一ポイントを削除 706 >>> result = await manager.delete(["doc-001"]) 707 >>> print(f"Deleted: {result['data']['deleted_count']}") 708 >>> 709 >>> # 複数ポイントを削除 710 >>> result = await manager.delete(["doc-001", "doc-002", "doc-003"]) 711 >>> if result["success"]: 712 ... print(f"Deleted {result['data']['deleted_count']} documents") 713 """ 714 await self.ensure_initialized() 715 716 try: 717 # Sync call → asyncio.to_thread 718 await asyncio.to_thread( 719 self.client.delete, 720 collection_name=self.collection_name, 721 points_selector=point_ids, 722 ) 723 logger.info(f"Deleted {len(point_ids)} points") 724 return {"success": True, "data": {"deleted_count": len(point_ids)}} 725 726 except Exception as e: 727 logger.error(f"Failed to delete points: {e}") 728 return { 729 "success": False, 730 "error": f"Delete error: {str(e)}", 731 "error_code": "QDRANT_DELETE_ERROR", 732 } 733 734 async def get_statistics(self) -> Dict[str, Any]: 735 """ 736 コレクション統計情報を取得 737 738 コレクション内のドキュメント数やコンテンツタイプの内訳などを取得します。 739 740 Returns: 741 Dict[str, Any]: 統計情報 742 - success: bool 743 - data: { 744 "total_objects": int, 745 "vectors_count": int, 746 "content_type_breakdown": dict, 747 "collection_name": str, 748 "status": str 749 } 750 - error: str, error_code: str (失敗時) 751 752 Example: 753 >>> stats = await manager.get_statistics() 754 >>> if stats["success"]: 755 ... data = stats["data"] 756 ... print(f"Total documents: {data['total_objects']}") 757 ... print(f"Status: {data['status']}") 758 ... print("Content types:") 759 ... for ctype, count in data["content_type_breakdown"].items(): 760 ... print(f" {ctype}: {count}") 761 """ 762 await self.ensure_initialized() 763 764 try: 765 # Sync calls → asyncio.to_thread 766 collection_info = await asyncio.to_thread( 767 self.client.get_collection, 768 self.collection_name, 769 ) 770 771 # Content type breakdown (sample 1000 points) 772 scroll_result = await asyncio.to_thread( 773 self.client.scroll, 774 collection_name=self.collection_name, 775 limit=1000, 776 with_payload=["content_type"], 777 with_vectors=False, 778 ) 779 780 content_type_breakdown = {} 781 for point in scroll_result[0]: 782 content_type = point.payload.get("content_type", "unknown") 783 content_type_breakdown[content_type] = ( 784 content_type_breakdown.get(content_type, 0) + 1 785 ) 786 787 stats = { 788 "total_objects": collection_info.points_count, 789 "vectors_count": collection_info.vectors_count, 790 "content_type_breakdown": content_type_breakdown, 791 "collection_name": self.collection_name, 792 "status": collection_info.status.value, 793 } 794 795 logger.info(f"Statistics retrieved - total={stats['total_objects']}") 796 return {"success": True, "data": stats} 797 798 except Exception as e: 799 logger.error(f"Failed to get statistics: {e}") 800 return { 801 "success": False, 802 "error": f"Statistics error: {str(e)}", 803 "error_code": "QDRANT_STATS_ERROR", 804 } 805 806 async def delete_collection(self, collection_name: Optional[str] = None) -> Dict[str, Any]: 807 """ 808 コレクションを削除 809 810 Args: 811 collection_name: 削除するコレクション名(省略時は現在のコレクション) 812 813 Returns: 814 Dict[str, Any]: 削除結果 815 - success: bool 816 - data: {"collection_name": str} (成功時) 817 - error: str, error_code: str (失敗時) 818 819 Example: 820 >>> # 現在のコレクションを削除 821 >>> result = await manager.delete_collection() 822 >>> if result["success"]: 823 ... print(f"Deleted: {result['data']['collection_name']}") 824 825 >>> # 別のコレクションを削除 826 >>> result = await manager.delete_collection("old_collection") 827 """ 828 target_collection = collection_name or self.collection_name 829 830 try: 831 await asyncio.to_thread( 832 self.client.delete_collection, 833 target_collection, 834 ) 835 logger.info(f"Collection deleted: {target_collection}") 836 837 # 現在のコレクションを削除した場合は初期化状態をリセット 838 if target_collection == self.collection_name: 839 self._initialized = False 840 841 return {"success": True, "data": {"collection_name": target_collection}} 842 843 except Exception as e: 844 logger.error(f"Failed to delete collection: {e}") 845 return { 846 "success": False, 847 "error": f"Delete collection error: {str(e)}", 848 "error_code": "QDRANT_DELETE_COLLECTION_ERROR", 849 } 850 851 async def collection_exists(self, collection_name: Optional[str] = None) -> bool: 852 """ 853 コレクションが存在するか確認 854 855 Args: 856 collection_name: 確認するコレクション名(省略時は現在のコレクション) 857 858 Returns: 859 bool: コレクションが存在する場合True 860 861 Example: 862 >>> if await manager.collection_exists(): 863 ... print("Collection exists") 864 >>> else: 865 ... print("Collection does not exist") 866 """ 867 target_collection = collection_name or self.collection_name 868 869 try: 870 collections_response = await asyncio.to_thread( 871 self.client.get_collections 872 ) 873 collection_names = [col.name for col in collections_response.collections] 874 return target_collection in collection_names 875 876 except Exception as e: 877 logger.error(f"Failed to check collection existence: {e}") 878 return False 879 880 async def close(self) -> None: 881 """ 882 Close Qdrant client 883 884 Note: 885 例外が発生してもリソースを確実にクリーンアップします。 886 """ 887 try: 888 logger.info("Closing QdrantManager") 889 # Qdrantクライアントのクローズ処理(必要に応じて) 890 if self.client and hasattr(self.client, 'close'): 891 self.client.close() 892 except Exception as e: 893 logger.warning(f"Error closing Qdrant client: {e}") 894 finally: 895 self.client = None 896 self._initialized = False 897 logger.debug("QdrantManager closed")
Qdrant操作クラス
ベクトルデータベースQdrantへのアクセスを提供します。 テキストの埋め込み生成とベクトル検索を統合的に扱えます。
Features: - Collection自動作成(HNSW最適化設定) - Payload Index作成(検索フィルタ用) - Upsert/Batch Upsert (冪等性保証 - 同一IDは上書き) - Semantic Search (フィルタ対応) - 統計情報取得
Example:
from agenticstar_platform import QdrantManager, QdrantConfig, EmbeddingGenerator, EmbeddingConfig
設定を作成
embedding_config = EmbeddingConfig.from_toml("config.toml", section="rag.embedding") qdrant_config = QdrantConfig.from_toml("config.toml", section="rag.qdrant")
Embeddingジェネレーター作成
embedding_gen = EmbeddingGenerator(embedding_config)
コンテキストマネージャーで使用(推奨)
async with QdrantManager(qdrant_config, embedding_gen) as manager: ... # データ追加 ... await manager.upsert( ... id="doc-001", ... content="AIエージェントの設計パターン", ... metadata={"category": "technical", "author": "John"} ... ) ... ... # 検索 ... results = await manager.search( ... query_text="エージェント設計", ... limit=5, ... filter_conditions={"category": "technical"} ... ) ... for hit in results["data"]["results"]: ... print(f"Score: {hit['score']:.3f} - {hit['payload']}")
215 def __init__( 216 self, 217 config: QdrantConfig, 218 embedding_generator: EmbeddingGenerator, 219 ): 220 """ 221 Initialize Qdrant Manager 222 223 Args: 224 config: QdrantConfig instance 225 embedding_generator: EmbeddingGenerator instance 226 227 Raises: 228 QdrantConfigError: vector_sizeとembedding dimensionsの不一致 229 """ 230 # vector_size検証 231 if config.vector_size != embedding_generator.dimensions: 232 raise QdrantConfigError( 233 f"Vector size mismatch: QdrantConfig.vector_size={config.vector_size} " 234 f"!= EmbeddingConfig.dimensions={embedding_generator.dimensions}. " 235 f"These values must match for proper vector storage." 236 ) 237 238 # auth_token_provider設定時はプロキシ経由 239 # QdrantClientはurlパラメータ使用時にポート6333へデフォルト接続するため、 240 # CLIモード(プロキシ経由)ではhost/port/https/prefixで明示的に指定する必要がある 241 if config.auth_token_provider is not None: 242 # URLをパースしてhost/port/https/prefixを抽出 243 parsed = urlparse(config.url) 244 host = parsed.hostname 245 port = parsed.port or (443 if parsed.scheme == "https" else 80) 246 https = parsed.scheme == "https" 247 # パスからprefixを抽出(先頭の/を除去) 248 prefix = parsed.path.lstrip("/") if parsed.path and parsed.path != "/" else None 249 250 logger.info( 251 f"QdrantManager CLI mode - host={host}, port={port}, " 252 f"https={https}, prefix={prefix}" 253 ) 254 255 self.client = QdrantClient( 256 host=host, 257 port=port, 258 https=https, 259 prefix=prefix, 260 auth_token_provider=config.auth_token_provider, 261 prefer_grpc=False, # プロキシ経由のためgRPC無効化 262 check_compatibility=False, # バージョンチェック無効化(プロキシ経由では不要) 263 ) 264 else: 265 # 直接接続(Podモード等) 266 self.client = QdrantClient( 267 url=config.url, 268 prefer_grpc=True, 269 ) 270 271 self.config = config 272 self.collection_name = config.collection_name 273 self.embedding_generator = embedding_generator 274 self._initialized = False 275 276 logger.info( 277 f"QdrantManager initialized - url={config.url}, " 278 f"collection={config.collection_name}, vector_size={config.vector_size}" 279 )
Initialize Qdrant Manager
Args: config: QdrantConfig instance embedding_generator: EmbeddingGenerator instance
Raises: QdrantConfigError: vector_sizeとembedding dimensionsの不一致
294 async def ensure_initialized(self) -> None: 295 """初期化されていなければ初期化""" 296 if not self._initialized: 297 await self.initialize()
初期化されていなければ初期化
299 async def initialize(self) -> None: 300 """ 301 Collection存在確認・作成・Index作成 302 """ 303 if self._initialized: 304 logger.debug("QdrantManager already initialized") 305 return 306 307 logger.info(f"Initializing QdrantManager - collection={self.collection_name}") 308 309 try: 310 # Collection existence check (sync call → asyncio.to_thread) 311 collections_response = await asyncio.to_thread( 312 self.client.get_collections 313 ) 314 collection_names = [col.name for col in collections_response.collections] 315 316 if self.collection_name not in collection_names: 317 logger.info(f"Creating collection: {self.collection_name}") 318 await self._create_collection() 319 await self._create_payload_indices() 320 else: 321 logger.info(f"Collection already exists: {self.collection_name}") 322 323 self._initialized = True 324 logger.info("QdrantManager initialized successfully") 325 326 except Exception as e: 327 logger.error(f"Failed to initialize QdrantManager: {e}") 328 raise
Collection存在確認・作成・Index作成
376 async def create_payload_index( 377 self, 378 field_name: str, 379 field_schema: str = "keyword" 380 ) -> Dict[str, Any]: 381 """ 382 カスタムPayload Indexを作成 383 384 Args: 385 field_name: フィールド名 386 field_schema: スキーマタイプ ("keyword", "integer", "float", "bool") 387 388 Returns: 389 Dict with success status 390 """ 391 try: 392 await asyncio.to_thread( 393 self.client.create_payload_index, 394 collection_name=self.collection_name, 395 field_name=field_name, 396 field_schema=field_schema, 397 ) 398 logger.info(f"Payload index created: {field_name}") 399 return {"success": True, "field_name": field_name} 400 except Exception as e: 401 logger.error(f"Failed to create payload index: {e}") 402 return {"success": False, "error": str(e), "error_code": "INDEX_CREATE_ERROR"}
カスタムPayload Indexを作成
Args: field_name: フィールド名 field_schema: スキーマタイプ ("keyword", "integer", "float", "bool")
Returns: Dict with success status
404 async def upsert( 405 self, 406 id: Optional[str] = None, 407 content: Optional[str] = None, 408 metadata: Optional[Dict[str, Any]] = None, 409 ) -> Dict[str, Any]: 410 """ 411 ベクトルデータをUpsert(冪等性保証) 412 413 テキストからEmbeddingを生成し、ベクトルデータベースに保存します。 414 同一IDが既に存在する場合は上書きされます(冪等性保証)。 415 416 Args: 417 id: ドキュメントID(省略時は自動UUID生成) 418 content: 埋め込み対象のテキスト 419 metadata: メタデータ(検索時のフィルタリングや結果表示に使用) 420 421 Returns: 422 Dict[str, Any]: 実行結果 423 - success: bool 424 - data: {"id": str} (成功時) 425 - error: str, error_code: str (失敗時) 426 427 Example: 428 >>> result = await manager.upsert( 429 ... id="doc-001", # 省略可(自動UUID生成) 430 ... content="Pythonは汎用プログラミング言語です。", 431 ... metadata={"title": "Python入門", "category": "programming"} 432 ... ) 433 >>> if result["success"]: 434 ... print(f"Added: {result['data']['id']}") 435 436 >>> # ID自動生成 437 >>> result = await manager.upsert( 438 ... content="テキスト内容", 439 ... metadata={"source": "web"} 440 ... ) 441 >>> print(f"Auto-generated ID: {result['data']['id']}") 442 """ 443 await self.ensure_initialized() 444 445 import uuid as uuid_module 446 actual_id = id or str(uuid_module.uuid4()) 447 actual_metadata = metadata or {} 448 449 if not content: 450 return { 451 "success": False, 452 "error": "'content' is required", 453 "error_code": "MISSING_CONTENT", 454 } 455 456 try: 457 logger.debug(f"Upserting point - id={actual_id}") 458 459 # Generate embedding 460 embedding = await self.embedding_generator.generate(content) 461 462 # Add default metadata 463 full_payload = { 464 **actual_metadata, 465 "created_at": int(datetime.now().timestamp()), 466 "model_version": self.embedding_generator.model_name, 467 } 468 469 # Build Point 470 point = PointStruct( 471 id=actual_id, 472 vector=embedding, 473 payload=full_payload, 474 ) 475 476 # Upsert (idempotent - same ID overwrites, sync call → asyncio.to_thread) 477 await asyncio.to_thread( 478 self.client.upsert, 479 collection_name=self.collection_name, 480 points=[point], 481 ) 482 483 logger.info(f"Point upserted - id={actual_id}") 484 return {"success": True, "data": {"id": actual_id}} 485 486 except Exception as e: 487 logger.error(f"Failed to upsert point: {e}") 488 return { 489 "success": False, 490 "error": f"Upsert error: {str(e)}", 491 "error_code": "QDRANT_UPSERT_ERROR", 492 }
ベクトルデータをUpsert(冪等性保証)
テキストからEmbeddingを生成し、ベクトルデータベースに保存します。 同一IDが既に存在する場合は上書きされます(冪等性保証)。
Args: id: ドキュメントID(省略時は自動UUID生成) content: 埋め込み対象のテキスト metadata: メタデータ(検索時のフィルタリングや結果表示に使用)
Returns: Dict[str, Any]: 実行結果 - success: bool - data: {"id": str} (成功時) - error: str, error_code: str (失敗時)
Example:
result = await manager.upsert( ... id="doc-001", # 省略可(自動UUID生成) ... content="Pythonは汎用プログラミング言語です。", ... metadata={"title": "Python入門", "category": "programming"} ... ) if result["success"]: ... print(f"Added: {result['data']['id']}")
>>> # ID自動生成 >>> result = await manager.upsert( ... content="テキスト内容", ... metadata={"source": "web"} ... ) >>> print(f"Auto-generated ID: {result['data']['id']}")
494 async def batch_upsert( 495 self, 496 items: List[Dict[str, Any]], 497 batch_size: int = 256, 498 ) -> Dict[str, Any]: 499 """ 500 バッチUpsert 501 502 Args: 503 items: List of {"id": str, "content": str, "metadata": dict} 504 batch_size: バッチサイズ 505 506 Returns: 507 Dict with success status and total_upserted 508 """ 509 await self.ensure_initialized() 510 511 try: 512 logger.info(f"Batch upserting {len(items)} points") 513 514 total_points = 0 515 516 for i in range(0, len(items), batch_size): 517 batch = items[i : i + batch_size] 518 519 # Extract texts and generate embeddings 520 texts = [item["content"] for item in batch] 521 embeddings = await self.embedding_generator.batch_generate(texts) 522 523 # Build Points 524 points = [] 525 for item, embedding in zip(batch, embeddings): 526 full_payload = { 527 **item.get("metadata", {}), 528 "created_at": int(datetime.now().timestamp()), 529 "model_version": self.embedding_generator.model_name, 530 } 531 points.append( 532 PointStruct( 533 id=item["id"], 534 vector=embedding, 535 payload=full_payload, 536 ) 537 ) 538 539 # Batch Upsert (sync call → asyncio.to_thread) 540 await asyncio.to_thread( 541 self.client.upsert, 542 collection_name=self.collection_name, 543 points=points, 544 ) 545 total_points += len(points) 546 547 logger.debug(f"Batch upserted {len(points)} points") 548 549 logger.info(f"Batch upsert completed - total={total_points}") 550 return {"success": True, "data": {"total_upserted": total_points}} 551 552 except Exception as e: 553 logger.error(f"Failed to batch upsert: {e}") 554 return { 555 "success": False, 556 "error": f"Batch upsert error: {str(e)}", 557 "error_code": "QDRANT_BATCH_UPSERT_ERROR", 558 }
バッチUpsert
Args: items: List of {"id": str, "content": str, "metadata": dict} batch_size: バッチサイズ
Returns: Dict with success status and total_upserted
560 async def search( 561 self, 562 query_text: str, 563 limit: int = 10, 564 score_threshold: float = 0.4, 565 filter_conditions: Optional[Dict[str, Any]] = None, 566 ef: Optional[int] = None, 567 ) -> Dict[str, Any]: 568 """ 569 セマンティック検索 570 571 クエリテキストに意味的に近いドキュメントを検索します。 572 フィルタ条件を指定して結果を絞り込むことも可能です。 573 574 Args: 575 query_text: 検索クエリ。自然言語で記述 576 limit: 結果数上限(デフォルト: 10) 577 score_threshold: スコア閾値 (Qdrant cosine: -1〜1、デフォルト: 0.4) 578 filter_conditions: フィルタ条件 {"field": value} 形式(オプション) 579 ef: HNSW検索深度 (デフォルト: limit * 2)。大きいほど精度が上がるが遅くなる 580 581 Returns: 582 Dict[str, Any]: 検索結果 583 - success: bool 584 - data: { 585 "query": str, 586 "results": [{"point_id": str, "payload": dict, "score": float, "similarity": float}], 587 "total_found": int, 588 "score_threshold": float 589 } 590 - error: str, error_code: str (失敗時) 591 592 Example: 593 >>> # 基本的な検索 594 >>> results = await manager.search( 595 ... query_text="プログラミング言語", 596 ... limit=5 597 ... ) 598 >>> for hit in results["data"]["results"]: 599 ... print(f"Score: {hit['score']:.3f} - {hit['payload']['title']}") 600 >>> 601 >>> # フィルタ付き検索 602 >>> results = await manager.search( 603 ... query_text="データ分析", 604 ... limit=10, 605 ... score_threshold=0.5, 606 ... filter_conditions={"category": "programming"} 607 ... ) 608 >>> print(f"Found {results['data']['total_found']} results") 609 """ 610 await self.ensure_initialized() 611 612 try: 613 logger.debug( 614 f"Searching - query_len={len(query_text)}, limit={limit}" 615 ) 616 617 # Generate query embedding 618 query_vector = await self.embedding_generator.generate(query_text) 619 620 # Build filter 621 query_filter = None 622 if filter_conditions: 623 must_conditions = [ 624 FieldCondition( 625 key=key, 626 match=MatchValue(value=value), 627 ) 628 for key, value in filter_conditions.items() 629 ] 630 query_filter = Filter(must=must_conditions) 631 632 # Search params 633 search_params = SearchParams( 634 hnsw_ef=ef or (limit * 2), 635 exact=False, 636 ) 637 638 # Execute search (sync call → asyncio.to_thread) 639 results = await asyncio.to_thread( 640 self.client.query_points, 641 collection_name=self.collection_name, 642 query=query_vector, 643 query_filter=query_filter, 644 limit=limit, 645 score_threshold=score_threshold, 646 search_params=search_params, 647 with_payload=True, 648 ) 649 650 # Format results 651 formatted_results = [ 652 { 653 "point_id": hit.id, 654 "payload": hit.payload, 655 "score": hit.score, 656 "similarity": self._score_to_similarity(hit.score), 657 } 658 for hit in results.points 659 ] 660 661 logger.info( 662 f"Search completed - found={len(formatted_results)}, " 663 f"query_len={len(query_text)}" 664 ) 665 return { 666 "success": True, 667 "data": { 668 "query": query_text, 669 "results": formatted_results, 670 "total_found": len(formatted_results), 671 "score_threshold": score_threshold, 672 }, 673 } 674 675 except Exception as e: 676 logger.error(f"Failed to search: {e}") 677 return { 678 "success": False, 679 "error": f"Search error: {str(e)}", 680 "error_code": "QDRANT_SEARCH_ERROR", 681 }
セマンティック検索
クエリテキストに意味的に近いドキュメントを検索します。 フィルタ条件を指定して結果を絞り込むことも可能です。
Args: query_text: 検索クエリ。自然言語で記述 limit: 結果数上限(デフォルト: 10) score_threshold: スコア閾値 (Qdrant cosine: -1〜1、デフォルト: 0.4) filter_conditions: フィルタ条件 {"field": value} 形式(オプション) ef: HNSW検索深度 (デフォルト: limit * 2)。大きいほど精度が上がるが遅くなる
Returns: Dict[str, Any]: 検索結果 - success: bool - data: { "query": str, "results": [{"point_id": str, "payload": dict, "score": float, "similarity": float}], "total_found": int, "score_threshold": float } - error: str, error_code: str (失敗時)
Example:
基本的な検索
results = await manager.search( ... query_text="プログラミング言語", ... limit=5 ... ) for hit in results["data"]["results"]: ... print(f"Score: {hit['score']:.3f} - {hit['payload']['title']}")
フィルタ付き検索
results = await manager.search( ... query_text="データ分析", ... limit=10, ... score_threshold=0.5, ... filter_conditions={"category": "programming"} ... ) print(f"Found {results['data']['total_found']} results")
689 async def delete(self, point_ids: List[str]) -> Dict[str, Any]: 690 """ 691 ポイントを削除 692 693 指定したIDのベクトルデータを削除します。 694 695 Args: 696 point_ids: 削除するポイントIDのリスト 697 698 Returns: 699 Dict[str, Any]: 削除結果 700 - success: bool 701 - data: {"deleted_count": int} (成功時) 702 - error: str, error_code: str (失敗時) 703 704 Example: 705 >>> # 単一ポイントを削除 706 >>> result = await manager.delete(["doc-001"]) 707 >>> print(f"Deleted: {result['data']['deleted_count']}") 708 >>> 709 >>> # 複数ポイントを削除 710 >>> result = await manager.delete(["doc-001", "doc-002", "doc-003"]) 711 >>> if result["success"]: 712 ... print(f"Deleted {result['data']['deleted_count']} documents") 713 """ 714 await self.ensure_initialized() 715 716 try: 717 # Sync call → asyncio.to_thread 718 await asyncio.to_thread( 719 self.client.delete, 720 collection_name=self.collection_name, 721 points_selector=point_ids, 722 ) 723 logger.info(f"Deleted {len(point_ids)} points") 724 return {"success": True, "data": {"deleted_count": len(point_ids)}} 725 726 except Exception as e: 727 logger.error(f"Failed to delete points: {e}") 728 return { 729 "success": False, 730 "error": f"Delete error: {str(e)}", 731 "error_code": "QDRANT_DELETE_ERROR", 732 }
ポイントを削除
指定したIDのベクトルデータを削除します。
Args: point_ids: 削除するポイントIDのリスト
Returns: Dict[str, Any]: 削除結果 - success: bool - data: {"deleted_count": int} (成功時) - error: str, error_code: str (失敗時)
Example:
単一ポイントを削除
result = await manager.delete(["doc-001"]) print(f"Deleted: {result['data']['deleted_count']}")
複数ポイントを削除
result = await manager.delete(["doc-001", "doc-002", "doc-003"]) if result["success"]: ... print(f"Deleted {result['data']['deleted_count']} documents")
734 async def get_statistics(self) -> Dict[str, Any]: 735 """ 736 コレクション統計情報を取得 737 738 コレクション内のドキュメント数やコンテンツタイプの内訳などを取得します。 739 740 Returns: 741 Dict[str, Any]: 統計情報 742 - success: bool 743 - data: { 744 "total_objects": int, 745 "vectors_count": int, 746 "content_type_breakdown": dict, 747 "collection_name": str, 748 "status": str 749 } 750 - error: str, error_code: str (失敗時) 751 752 Example: 753 >>> stats = await manager.get_statistics() 754 >>> if stats["success"]: 755 ... data = stats["data"] 756 ... print(f"Total documents: {data['total_objects']}") 757 ... print(f"Status: {data['status']}") 758 ... print("Content types:") 759 ... for ctype, count in data["content_type_breakdown"].items(): 760 ... print(f" {ctype}: {count}") 761 """ 762 await self.ensure_initialized() 763 764 try: 765 # Sync calls → asyncio.to_thread 766 collection_info = await asyncio.to_thread( 767 self.client.get_collection, 768 self.collection_name, 769 ) 770 771 # Content type breakdown (sample 1000 points) 772 scroll_result = await asyncio.to_thread( 773 self.client.scroll, 774 collection_name=self.collection_name, 775 limit=1000, 776 with_payload=["content_type"], 777 with_vectors=False, 778 ) 779 780 content_type_breakdown = {} 781 for point in scroll_result[0]: 782 content_type = point.payload.get("content_type", "unknown") 783 content_type_breakdown[content_type] = ( 784 content_type_breakdown.get(content_type, 0) + 1 785 ) 786 787 stats = { 788 "total_objects": collection_info.points_count, 789 "vectors_count": collection_info.vectors_count, 790 "content_type_breakdown": content_type_breakdown, 791 "collection_name": self.collection_name, 792 "status": collection_info.status.value, 793 } 794 795 logger.info(f"Statistics retrieved - total={stats['total_objects']}") 796 return {"success": True, "data": stats} 797 798 except Exception as e: 799 logger.error(f"Failed to get statistics: {e}") 800 return { 801 "success": False, 802 "error": f"Statistics error: {str(e)}", 803 "error_code": "QDRANT_STATS_ERROR", 804 }
コレクション統計情報を取得
コレクション内のドキュメント数やコンテンツタイプの内訳などを取得します。
Returns: Dict[str, Any]: 統計情報 - success: bool - data: { "total_objects": int, "vectors_count": int, "content_type_breakdown": dict, "collection_name": str, "status": str } - error: str, error_code: str (失敗時)
Example:
stats = await manager.get_statistics() if stats["success"]: ... data = stats["data"] ... print(f"Total documents: {data['total_objects']}") ... print(f"Status: {data['status']}") ... print("Content types:") ... for ctype, count in data["content_type_breakdown"].items(): ... print(f" {ctype}: {count}")
806 async def delete_collection(self, collection_name: Optional[str] = None) -> Dict[str, Any]: 807 """ 808 コレクションを削除 809 810 Args: 811 collection_name: 削除するコレクション名(省略時は現在のコレクション) 812 813 Returns: 814 Dict[str, Any]: 削除結果 815 - success: bool 816 - data: {"collection_name": str} (成功時) 817 - error: str, error_code: str (失敗時) 818 819 Example: 820 >>> # 現在のコレクションを削除 821 >>> result = await manager.delete_collection() 822 >>> if result["success"]: 823 ... print(f"Deleted: {result['data']['collection_name']}") 824 825 >>> # 別のコレクションを削除 826 >>> result = await manager.delete_collection("old_collection") 827 """ 828 target_collection = collection_name or self.collection_name 829 830 try: 831 await asyncio.to_thread( 832 self.client.delete_collection, 833 target_collection, 834 ) 835 logger.info(f"Collection deleted: {target_collection}") 836 837 # 現在のコレクションを削除した場合は初期化状態をリセット 838 if target_collection == self.collection_name: 839 self._initialized = False 840 841 return {"success": True, "data": {"collection_name": target_collection}} 842 843 except Exception as e: 844 logger.error(f"Failed to delete collection: {e}") 845 return { 846 "success": False, 847 "error": f"Delete collection error: {str(e)}", 848 "error_code": "QDRANT_DELETE_COLLECTION_ERROR", 849 }
コレクションを削除
Args: collection_name: 削除するコレクション名(省略時は現在のコレクション)
Returns: Dict[str, Any]: 削除結果 - success: bool - data: {"collection_name": str} (成功時) - error: str, error_code: str (失敗時)
Example:
現在のコレクションを削除
result = await manager.delete_collection() if result["success"]: ... print(f"Deleted: {result['data']['collection_name']}")
>>> # 別のコレクションを削除 >>> result = await manager.delete_collection("old_collection")
851 async def collection_exists(self, collection_name: Optional[str] = None) -> bool: 852 """ 853 コレクションが存在するか確認 854 855 Args: 856 collection_name: 確認するコレクション名(省略時は現在のコレクション) 857 858 Returns: 859 bool: コレクションが存在する場合True 860 861 Example: 862 >>> if await manager.collection_exists(): 863 ... print("Collection exists") 864 >>> else: 865 ... print("Collection does not exist") 866 """ 867 target_collection = collection_name or self.collection_name 868 869 try: 870 collections_response = await asyncio.to_thread( 871 self.client.get_collections 872 ) 873 collection_names = [col.name for col in collections_response.collections] 874 return target_collection in collection_names 875 876 except Exception as e: 877 logger.error(f"Failed to check collection existence: {e}") 878 return False
コレクションが存在するか確認
Args: collection_name: 確認するコレクション名(省略時は現在のコレクション)
Returns: bool: コレクションが存在する場合True
Example:
if await manager.collection_exists(): ... print("Collection exists") else: ... print("Collection does not exist")
880 async def close(self) -> None: 881 """ 882 Close Qdrant client 883 884 Note: 885 例外が発生してもリソースを確実にクリーンアップします。 886 """ 887 try: 888 logger.info("Closing QdrantManager") 889 # Qdrantクライアントのクローズ処理(必要に応じて) 890 if self.client and hasattr(self.client, 'close'): 891 self.client.close() 892 except Exception as e: 893 logger.warning(f"Error closing Qdrant client: {e}") 894 finally: 895 self.client = None 896 self._initialized = False 897 logger.debug("QdrantManager closed")
Close Qdrant client
Note: 例外が発生してもリソースを確実にクリーンアップします。
66@dataclass 67class QdrantConfig: 68 """Qdrant設定 69 70 Example: 71 >>> # 辞書から作成 72 >>> config = QdrantConfig.from_dict({ 73 ... "url": "http://localhost:6333", 74 ... "collection_name": "my_collection", 75 ... "vector_size": 1536 76 ... }) 77 78 >>> # TOMLファイルから作成 79 >>> config = QdrantConfig.from_toml("config.toml", section="rag.qdrant") 80 """ 81 url: str 82 collection_name: str 83 vector_size: int = 1536 # text-embedding-ada-002 84 distance: str = "cosine" # cosine, euclid, dot 85 on_disk_vectors: bool = False 86 on_disk_payload: bool = True 87 hnsw_m: int = 16 88 hnsw_ef_construct: int = 256 89 # Payload Indexes(カスタマイズ可能) 90 payload_indexes: List[PayloadIndexConfig] = field(default_factory=lambda: [ 91 PayloadIndexConfig(field_name="content_type", field_schema="keyword"), 92 PayloadIndexConfig(field_name="created_at", field_schema="integer"), 93 ]) 94 # CLIモード用認証トークン取得関数(QdrantClientのauth_token_providerに渡す) 95 # 設定時はプロキシ経由のためgRPCを無効化 96 auth_token_provider: Optional[Callable[[], str]] = None 97 98 @classmethod 99 def from_dict(cls, data: Dict[str, Any]) -> "QdrantConfig": 100 """辞書からQdrantConfigを作成 101 102 Args: 103 data: 設定辞書 104 105 Returns: 106 QdrantConfig instance 107 108 Example: 109 >>> config = QdrantConfig.from_dict({ 110 ... "url": "http://localhost:6333", 111 ... "collection_name": "my_collection", 112 ... "vector_size": 1536, 113 ... "distance": "cosine" 114 ... }) 115 """ 116 # payload_indexesがある場合は変換 117 payload_indexes = None 118 if "payload_indexes" in data: 119 payload_indexes = [ 120 PayloadIndexConfig.from_dict(idx) if isinstance(idx, dict) else idx 121 for idx in data["payload_indexes"] 122 ] 123 124 return cls( 125 url=data["url"], 126 collection_name=data["collection_name"], 127 vector_size=data.get("vector_size", 1536), 128 distance=data.get("distance", "cosine"), 129 on_disk_vectors=data.get("on_disk_vectors", False), 130 on_disk_payload=data.get("on_disk_payload", True), 131 hnsw_m=data.get("hnsw_m", 16), 132 hnsw_ef_construct=data.get("hnsw_ef_construct", 256), 133 payload_indexes=payload_indexes if payload_indexes else [ 134 PayloadIndexConfig(field_name="content_type", field_schema="keyword"), 135 PayloadIndexConfig(field_name="created_at", field_schema="integer"), 136 ], 137 # auth_token_providerは辞書から設定不可(コールバック関数のため) 138 ) 139 140 @classmethod 141 def from_toml(cls, toml_path: str, section: str = "rag.qdrant") -> "QdrantConfig": 142 """TOMLファイルからQdrantConfigを作成 143 144 Args: 145 toml_path: TOMLファイルパス 146 section: セクション名(ドット区切りでネスト対応) 147 148 Returns: 149 QdrantConfig instance 150 151 Example: 152 >>> # config.toml の [rag.qdrant] セクションを読み込み 153 >>> config = QdrantConfig.from_toml("config.toml") 154 155 >>> # カスタムセクション 156 >>> config = QdrantConfig.from_toml("config.toml", section="vectordb") 157 """ 158 if not TOML_AVAILABLE: 159 raise ImportError("toml library is required: pip install toml") 160 161 with open(toml_path, "r") as f: 162 toml_data = toml.load(f) 163 164 # ネストされたセクションをドット区切りで辿る 165 data = toml_data 166 for key in section.split("."): 167 data = data[key] 168 169 return cls.from_dict(data)
Qdrant設定
Example:
辞書から作成
config = QdrantConfig.from_dict({ ... "url": "http://localhost:6333", ... "collection_name": "my_collection", ... "vector_size": 1536 ... })
>>> # TOMLファイルから作成 >>> config = QdrantConfig.from_toml("config.toml", section="rag.qdrant")
98 @classmethod 99 def from_dict(cls, data: Dict[str, Any]) -> "QdrantConfig": 100 """辞書からQdrantConfigを作成 101 102 Args: 103 data: 設定辞書 104 105 Returns: 106 QdrantConfig instance 107 108 Example: 109 >>> config = QdrantConfig.from_dict({ 110 ... "url": "http://localhost:6333", 111 ... "collection_name": "my_collection", 112 ... "vector_size": 1536, 113 ... "distance": "cosine" 114 ... }) 115 """ 116 # payload_indexesがある場合は変換 117 payload_indexes = None 118 if "payload_indexes" in data: 119 payload_indexes = [ 120 PayloadIndexConfig.from_dict(idx) if isinstance(idx, dict) else idx 121 for idx in data["payload_indexes"] 122 ] 123 124 return cls( 125 url=data["url"], 126 collection_name=data["collection_name"], 127 vector_size=data.get("vector_size", 1536), 128 distance=data.get("distance", "cosine"), 129 on_disk_vectors=data.get("on_disk_vectors", False), 130 on_disk_payload=data.get("on_disk_payload", True), 131 hnsw_m=data.get("hnsw_m", 16), 132 hnsw_ef_construct=data.get("hnsw_ef_construct", 256), 133 payload_indexes=payload_indexes if payload_indexes else [ 134 PayloadIndexConfig(field_name="content_type", field_schema="keyword"), 135 PayloadIndexConfig(field_name="created_at", field_schema="integer"), 136 ], 137 # auth_token_providerは辞書から設定不可(コールバック関数のため) 138 )
辞書からQdrantConfigを作成
Args: data: 設定辞書
Returns: QdrantConfig instance
Example:
config = QdrantConfig.from_dict({ ... "url": "http://localhost:6333", ... "collection_name": "my_collection", ... "vector_size": 1536, ... "distance": "cosine" ... })
140 @classmethod 141 def from_toml(cls, toml_path: str, section: str = "rag.qdrant") -> "QdrantConfig": 142 """TOMLファイルからQdrantConfigを作成 143 144 Args: 145 toml_path: TOMLファイルパス 146 section: セクション名(ドット区切りでネスト対応) 147 148 Returns: 149 QdrantConfig instance 150 151 Example: 152 >>> # config.toml の [rag.qdrant] セクションを読み込み 153 >>> config = QdrantConfig.from_toml("config.toml") 154 155 >>> # カスタムセクション 156 >>> config = QdrantConfig.from_toml("config.toml", section="vectordb") 157 """ 158 if not TOML_AVAILABLE: 159 raise ImportError("toml library is required: pip install toml") 160 161 with open(toml_path, "r") as f: 162 toml_data = toml.load(f) 163 164 # ネストされたセクションをドット区切りで辿る 165 data = toml_data 166 for key in section.split("."): 167 data = data[key] 168 169 return cls.from_dict(data)
TOMLファイルからQdrantConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: QdrantConfig instance
Example:
config.toml の [rag.qdrant] セクションを読み込み
config = QdrantConfig.from_toml("config.toml")
>>> # カスタムセクション >>> config = QdrantConfig.from_toml("config.toml", section="vectordb")
12class EventType(str, Enum): 13 """ 14 AgentLoopPublisher粒度のイベントタイプ定義 15 16 フロントエンドでのUI表示制御に使用される主要イベントタイプ 17 """ 18 19 # フェーズ系イベント 20 PHASE_START = "phase_start" 21 PROGRESS_UPDATE = "progress_update" 22 23 # AI思考プロセス 24 THOUGHT_MESSAGE = "thought_message" 25 26 # 完了系イベント 27 COMPLETION_SUCCESS = "completion_success" 28 COMPLETION_FAILURE = "completion_failure" 29 30 # 相互作用イベント 31 USER_INTERACTION_REQUIRED = "user_interaction_required" 32 UNEXPECTED_ERROR = "unexpected_error" 33 34 # HITL(Human-In-The-Loop)関連イベント 35 HITL_REQUIRED_BROWSER_VNC = "hitl_required_browser_vnc" # Pod環境: VNC経由でブラウザ介入 36 HITL_REQUIRED_BROWSER_CLI = "hitl_required_browser_cli" # CLI環境: ローカルブラウザ介入 37 HITL_COMPLETED = "hitl_completed" # HITL操作完了 38 39 # ファイル作成イベント 40 FILE_CREATED = "file_created" 41 42 # パーミッション関連イベント 43 PERMISSION_REQUEST = "permission_request" 44 PERMISSION_RESPONSE = "permission_response" 45 46 # 統一ツールイベント 47 TOOL_START = "tool_start" # ツール実行開始 48 TOOL_RESULT = "tool_result" # ツール実行結果 49 50 # 最終結果 51 FINAL_RESULT = "final_result"
AgentLoopPublisher粒度のイベントタイプ定義
フロントエンドでのUI表示制御に使用される主要イベントタイプ
54class SubEventType(str, Enum): 55 """ 56 ツール結果のサブイベントタイプ定義 57 58 フロントエンドでの表示出し分けに使用 59 例: search_webはWeb検索アイコン、file_editedはファイル編集アイコン 60 """ 61 62 # Web検索関連 63 SEARCH_WEB = "search_web" 64 65 # コマンド実行関連 66 COMMAND_EXECUTION = "command_execution" 67 68 # ファイル操作関連 69 FILE_OPERATION = "file_operation" 70 71 # ローカルアシスタント関連 72 LOCAL_ASSISTANT = "local_assistant" 73 74 # MCPツール関連 75 MCP_TOOL = "mcp_tool" 76 77 # OpenCode互換ツールイベント 78 FILE_EDITED = "file_edited" # write, edit, multiedit 79 FILE_READ = "file_read" # read 80 FILE_SEARCHED = "file_searched" # grep, glob, ls 81 BASH_EXECUTED = "bash_executed" # bash 82 WEB_FETCHED = "web_fetched" # webfetch 83 TASK_LAUNCHED = "task_launched" # task 84 TODO_UPDATED = "todo_updated" # todo 85 VIDEO_GENERATED = "video_generated" # video_gen 86 IMAGE_GENERATED = "image_generated" # generation_image 87 88 # Slide generation 89 SLIDE_CREATED = "slide_created" # save_html (slide preview) 90 91 # macOS Automation 92 MACOS_AUTOMATION = "macos_automation" # macOS UI/app automation tools
ツール結果のサブイベントタイプ定義
フロントエンドでの表示出し分けに使用 例: search_webはWeb検索アイコン、file_editedはファイル編集アイコン
21@dataclass 22class StreamingEvent: 23 """ 24 ストリーミングイベントデータ 25 26 リアルタイムストリーミング(SSE等)で使用される基本イベント構造 27 28 Attributes: 29 event_type: イベントタイプ(EventType enum) 30 execution_id: 実行ID 31 message: イベントメッセージ 32 timestamp: イベント発生時刻(Unix timestamp、省略時は現在時刻) 33 metadata: 追加メタデータ(オプション) 34 sub_event_type: サブイベントタイプ(フロント表示用、オプション) 35 """ 36 37 event_type: EventType 38 execution_id: str 39 message: str 40 timestamp: float = field(default_factory=_current_timestamp) 41 metadata: Optional[Dict[str, Any]] = None 42 sub_event_type: Optional[SubEventType] = None 43 44 def to_dict(self) -> Dict[str, Any]: 45 """辞書形式に変換""" 46 result = { 47 "event_type": self.event_type.value, 48 "execution_id": self.execution_id, 49 "message": self.message, 50 "timestamp": self.timestamp, 51 } 52 if self.metadata is not None: 53 result["metadata"] = self.metadata 54 if self.sub_event_type is not None: 55 result["sub_event_type"] = self.sub_event_type.value 56 return result
ストリーミングイベントデータ
リアルタイムストリーミング(SSE等)で使用される基本イベント構造
Attributes: event_type: イベントタイプ(EventType enum) execution_id: 実行ID message: イベントメッセージ timestamp: イベント発生時刻(Unix timestamp、省略時は現在時刻) metadata: 追加メタデータ(オプション) sub_event_type: サブイベントタイプ(フロント表示用、オプション)
44 def to_dict(self) -> Dict[str, Any]: 45 """辞書形式に変換""" 46 result = { 47 "event_type": self.event_type.value, 48 "execution_id": self.execution_id, 49 "message": self.message, 50 "timestamp": self.timestamp, 51 } 52 if self.metadata is not None: 53 result["metadata"] = self.metadata 54 if self.sub_event_type is not None: 55 result["sub_event_type"] = self.sub_event_type.value 56 return result
辞書形式に変換
59@dataclass 60class SequencedEvent: 61 """ 62 順序付きイベント 63 64 イベントの順序保証が必要な場合に使用 65 66 Attributes: 67 sequence: シーケンス番号(0から開始、負の値は不可) 68 event_type: イベントタイプ(EventType enum) 69 execution_id: 実行ID 70 message: イベントメッセージ 71 timestamp: イベント発生時刻(Unix timestamp、省略時は現在時刻) 72 metadata: 追加メタデータ(オプション) 73 sub_event_type: サブイベントタイプ(フロント表示用、オプション) 74 """ 75 76 sequence: int 77 event_type: EventType 78 execution_id: str 79 message: str 80 timestamp: float = field(default_factory=_current_timestamp) 81 metadata: Optional[Dict[str, Any]] = None 82 sub_event_type: Optional[SubEventType] = None 83 84 def __post_init__(self): 85 """シーケンス番号のバリデーション""" 86 if self.sequence < 0: 87 raise ValueError(f"sequence must be >= 0, got {self.sequence}") 88 89 def to_streaming_event(self) -> StreamingEvent: 90 """StreamingEventに変換""" 91 return StreamingEvent( 92 event_type=self.event_type, 93 execution_id=self.execution_id, 94 message=self.message, 95 timestamp=self.timestamp, 96 metadata=self.metadata, 97 sub_event_type=self.sub_event_type, 98 ) 99 100 def to_dict(self) -> Dict[str, Any]: 101 """辞書形式に変換""" 102 result = { 103 "sequence": self.sequence, 104 "event_type": self.event_type.value, 105 "execution_id": self.execution_id, 106 "message": self.message, 107 "timestamp": self.timestamp, 108 } 109 if self.metadata is not None: 110 result["metadata"] = self.metadata 111 if self.sub_event_type is not None: 112 result["sub_event_type"] = self.sub_event_type.value 113 return result
順序付きイベント
イベントの順序保証が必要な場合に使用
Attributes: sequence: シーケンス番号(0から開始、負の値は不可) event_type: イベントタイプ(EventType enum) execution_id: 実行ID message: イベントメッセージ timestamp: イベント発生時刻(Unix timestamp、省略時は現在時刻) metadata: 追加メタデータ(オプション) sub_event_type: サブイベントタイプ(フロント表示用、オプション)
89 def to_streaming_event(self) -> StreamingEvent: 90 """StreamingEventに変換""" 91 return StreamingEvent( 92 event_type=self.event_type, 93 execution_id=self.execution_id, 94 message=self.message, 95 timestamp=self.timestamp, 96 metadata=self.metadata, 97 sub_event_type=self.sub_event_type, 98 )
StreamingEventに変換
100 def to_dict(self) -> Dict[str, Any]: 101 """辞書形式に変換""" 102 result = { 103 "sequence": self.sequence, 104 "event_type": self.event_type.value, 105 "execution_id": self.execution_id, 106 "message": self.message, 107 "timestamp": self.timestamp, 108 } 109 if self.metadata is not None: 110 result["metadata"] = self.metadata 111 if self.sub_event_type is not None: 112 result["sub_event_type"] = self.sub_event_type.value 113 return result
辞書形式に変換
116@dataclass 117class ExecutionMessage: 118 """ 119 実行メッセージデータ(DB保存用) 120 121 エージェント実行中のイベントをデータベースに保存するための構造体。 122 マーケットプレイスUIとの連携に使用。 123 124 Attributes: 125 user_id: ユーザーID 126 conversation_id: 会話ID 127 message_id: メッセージID 128 chunk_type: チャンクタイプ(content, reasoning_content, task_created等) 129 content_data: コンテンツデータ(JSON) 130 131 Example: 132 >>> msg = ExecutionMessage( 133 ... user_id="user-123", 134 ... conversation_id="conv-456", 135 ... message_id="msg-789", 136 ... chunk_type="content", 137 ... content_data={"content": "処理中です..."} 138 ... ) 139 >>> await data_access.insert("execution_messages", msg.to_dict()) 140 """ 141 142 user_id: str 143 conversation_id: str 144 message_id: str 145 chunk_type: str 146 content_data: Dict[str, Any] 147 148 def to_dict(self) -> Dict[str, Any]: 149 """DB挿入用の辞書形式に変換""" 150 return { 151 "user_id": self.user_id, 152 "conversation_id": self.conversation_id, 153 "message_id": self.message_id, 154 "chunk_type": self.chunk_type, 155 "content_data": self.content_data, 156 } 157 158 @classmethod 159 def from_streaming_event( 160 cls, 161 event: "StreamingEvent", 162 user_id: str, 163 conversation_id: str, 164 message_id: str, 165 chunk_type: Optional[str] = None, 166 ) -> "ExecutionMessage": 167 """StreamingEventからExecutionMessageを作成 168 169 Args: 170 event: ストリーミングイベント 171 user_id: ユーザーID 172 conversation_id: 会話ID 173 message_id: メッセージID 174 chunk_type: チャンクタイプ(省略時はevent_typeから推定) 175 176 Returns: 177 ExecutionMessage instance 178 179 Example: 180 >>> msg = ExecutionMessage.from_streaming_event( 181 ... event=event, 182 ... user_id="user-123", 183 ... conversation_id="conv-456", 184 ... message_id="msg-789" 185 ... ) 186 """ 187 # チャンクタイプの推定(省略時) 188 if chunk_type is None: 189 chunk_type_map = { 190 EventType.PHASE_START: "content", 191 EventType.PROGRESS_UPDATE: "content", 192 EventType.THOUGHT_MESSAGE: "reasoning_content", 193 EventType.TOOL_START: "task_created", 194 EventType.TOOL_RESULT: "task_created", 195 EventType.FILE_CREATED: "task_created", 196 EventType.COMPLETION_SUCCESS: "content", 197 EventType.COMPLETION_FAILURE: "content", 198 } 199 chunk_type = chunk_type_map.get(event.event_type, "content") 200 201 content_data = { 202 "content": event.message, 203 } 204 if event.metadata: 205 content_data.update(event.metadata) 206 207 return cls( 208 user_id=user_id, 209 conversation_id=conversation_id, 210 message_id=message_id, 211 chunk_type=chunk_type, 212 content_data=content_data, 213 )
実行メッセージデータ(DB保存用)
エージェント実行中のイベントをデータベースに保存するための構造体。 マーケットプレイスUIとの連携に使用。
Attributes: user_id: ユーザーID conversation_id: 会話ID message_id: メッセージID chunk_type: チャンクタイプ(content, reasoning_content, task_created等) content_data: コンテンツデータ(JSON)
Example:
msg = ExecutionMessage( ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789", ... chunk_type="content", ... content_data={"content": "処理中です..."} ... ) await data_access.insert("execution_messages", msg.to_dict())
148 def to_dict(self) -> Dict[str, Any]: 149 """DB挿入用の辞書形式に変換""" 150 return { 151 "user_id": self.user_id, 152 "conversation_id": self.conversation_id, 153 "message_id": self.message_id, 154 "chunk_type": self.chunk_type, 155 "content_data": self.content_data, 156 }
DB挿入用の辞書形式に変換
158 @classmethod 159 def from_streaming_event( 160 cls, 161 event: "StreamingEvent", 162 user_id: str, 163 conversation_id: str, 164 message_id: str, 165 chunk_type: Optional[str] = None, 166 ) -> "ExecutionMessage": 167 """StreamingEventからExecutionMessageを作成 168 169 Args: 170 event: ストリーミングイベント 171 user_id: ユーザーID 172 conversation_id: 会話ID 173 message_id: メッセージID 174 chunk_type: チャンクタイプ(省略時はevent_typeから推定) 175 176 Returns: 177 ExecutionMessage instance 178 179 Example: 180 >>> msg = ExecutionMessage.from_streaming_event( 181 ... event=event, 182 ... user_id="user-123", 183 ... conversation_id="conv-456", 184 ... message_id="msg-789" 185 ... ) 186 """ 187 # チャンクタイプの推定(省略時) 188 if chunk_type is None: 189 chunk_type_map = { 190 EventType.PHASE_START: "content", 191 EventType.PROGRESS_UPDATE: "content", 192 EventType.THOUGHT_MESSAGE: "reasoning_content", 193 EventType.TOOL_START: "task_created", 194 EventType.TOOL_RESULT: "task_created", 195 EventType.FILE_CREATED: "task_created", 196 EventType.COMPLETION_SUCCESS: "content", 197 EventType.COMPLETION_FAILURE: "content", 198 } 199 chunk_type = chunk_type_map.get(event.event_type, "content") 200 201 content_data = { 202 "content": event.message, 203 } 204 if event.metadata: 205 content_data.update(event.metadata) 206 207 return cls( 208 user_id=user_id, 209 conversation_id=conversation_id, 210 message_id=message_id, 211 chunk_type=chunk_type, 212 content_data=content_data, 213 )
StreamingEventからExecutionMessageを作成
Args: event: ストリーミングイベント user_id: ユーザーID conversation_id: 会話ID message_id: メッセージID chunk_type: チャンクタイプ(省略時はevent_typeから推定)
Returns: ExecutionMessage instance
Example:
msg = ExecutionMessage.from_streaming_event( ... event=event, ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789" ... )
41class EventEmitter: 42 """ 43 イベント発行クライアント 44 45 非同期イベントキューを使用してイベントを発行・消費します。 46 シーケンス番号による順序保証を提供します。 47 48 Features: 49 - 非同期イベント発行(emit_event) 50 - シーケンス番号による順序保証 51 - 非同期イベント消費(consume_events) 52 - カスタムイベントハンドラー対応 53 54 Example: 55 >>> # 基本的な使い方 56 >>> emitter = EventEmitter(execution_id="exec-123") 57 >>> 58 >>> # イベント発行 59 >>> await emitter.emit_event( 60 ... EventType.PHASE_START, 61 ... "処理を開始します" 62 ... ) 63 >>> 64 >>> # イベント消費 65 >>> async for chunk in emitter.consume_events(): 66 ... print(chunk) 67 >>> 68 >>> # 完了通知 69 >>> await emitter.emit_event( 70 ... EventType.COMPLETION_SUCCESS, 71 ... "処理が完了しました" 72 ... ) 73 74 Note: 75 - emit_eventとconsume_eventsは非同期で実行してください 76 - 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)発行後、 77 consume_eventsは自動的に終了します 78 """ 79 80 def __init__( 81 self, 82 execution_id: str, 83 handler: Optional[EventHandler] = None, 84 ): 85 """ 86 EventEmitterを初期化 87 88 Args: 89 execution_id: 実行ID(イベントの識別に使用) 90 handler: イベントハンドラー(オプション) 91 指定しない場合、デフォルトのJSON変換が使用されます 92 93 Example: 94 >>> emitter = EventEmitter(execution_id="exec-123") 95 96 >>> # カスタムハンドラー付き 97 >>> async def custom_handler(event: StreamingEvent) -> Optional[str]: 98 ... return f"data: {event.to_dict()}\\n\\n" 99 >>> emitter = EventEmitter(execution_id="exec-123", handler=custom_handler) 100 """ 101 self.execution_id = execution_id 102 self._handler = handler 103 self._event_queue: asyncio.Queue[SequencedEvent] = asyncio.Queue() 104 self._sequence_number = 0 105 self._is_completed = False 106 self._completion_event = asyncio.Event() 107 108 logger.debug(f"[{execution_id}] EventEmitter initialized") 109 110 @property 111 def is_completed(self) -> bool: 112 """完了状態を取得""" 113 return self._is_completed 114 115 @property 116 def sequence_number(self) -> int: 117 """現在のシーケンス番号を取得""" 118 return self._sequence_number 119 120 async def emit_event( 121 self, 122 event_type: EventType, 123 message: str, 124 metadata: Optional[Dict[str, Any]] = None, 125 sub_event_type: Optional[SubEventType] = None, 126 ) -> None: 127 """ 128 イベントを発行 129 130 Args: 131 event_type: イベントタイプ 132 message: イベントメッセージ 133 metadata: 追加メタデータ(オプション) 134 sub_event_type: サブイベントタイプ(オプション) 135 136 Example: 137 >>> await emitter.emit_event( 138 ... EventType.THOUGHT_MESSAGE, 139 ... "処理中です...", 140 ... metadata={"progress": 50} 141 ... ) 142 143 Note: 144 完了イベント発行後にイベントを発行しようとすると、 145 警告ログが出力され、イベントは無視されます。 146 """ 147 if self._is_completed: 148 logger.warning( 149 f"[{self.execution_id}] Event emission after completion: {event_type}" 150 ) 151 return 152 153 sequenced_event = SequencedEvent( 154 sequence=self._sequence_number, 155 event_type=event_type, 156 execution_id=self.execution_id, 157 message=message, 158 timestamp=time.time(), 159 metadata=metadata, 160 sub_event_type=sub_event_type, 161 ) 162 self._sequence_number += 1 163 164 try: 165 await self._event_queue.put(sequenced_event) 166 logger.info( 167 f"[{self.execution_id}] Event queued: {event_type.value} " 168 f"(seq={sequenced_event.sequence})" 169 ) 170 except Exception as e: 171 logger.error( 172 f"[{self.execution_id}] Failed to queue event {event_type}: {e}" 173 ) 174 raise 175 176 async def emit( 177 self, 178 event_type: EventType, 179 message: str, 180 metadata: Optional[Dict[str, Any]] = None, 181 sub_event_type: Optional[SubEventType] = None, 182 ) -> None: 183 """emit_eventのエイリアス(短縮形)""" 184 await self.emit_event(event_type, message, metadata, sub_event_type) 185 186 async def consume_events( 187 self, 188 timeout: float = 0.1, 189 ) -> AsyncGenerator[str, None]: 190 """ 191 イベントを順序通りに消費してストリーミングチャンクを生成 192 193 Args: 194 timeout: イベント待機タイムアウト(秒) 195 196 Yields: 197 str: ストリーミングチャンク(ハンドラーの戻り値) 198 199 Example: 200 >>> async for chunk in emitter.consume_events(): 201 ... await response.write(chunk) 202 203 Note: 204 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)を受信し、 205 キューが空になった時点で自動的に終了します。 206 """ 207 logger.info(f"[{self.execution_id}] Starting event consumption") 208 209 while True: 210 try: 211 event = await asyncio.wait_for( 212 self._event_queue.get(), 213 timeout=timeout, 214 ) 215 216 logger.debug( 217 f"[{self.execution_id}] Processing event: {event.event_type.value} " 218 f"(seq={event.sequence})" 219 ) 220 221 # イベントをストリーミングチャンクに変換 222 chunk = await self._process_event(event) 223 if chunk: 224 yield chunk 225 226 # 完了イベントを処理したら完了状態に設定 227 if event.event_type in [ 228 EventType.COMPLETION_SUCCESS, 229 EventType.COMPLETION_FAILURE, 230 ]: 231 self._completion_event.set() 232 self._is_completed = True 233 logger.info( 234 f"[{self.execution_id}] Completion event processed" 235 ) 236 237 except asyncio.TimeoutError: 238 # 完了済みでキューが空なら終了 239 if self._completion_event.is_set() and self._event_queue.empty(): 240 logger.info(f"[{self.execution_id}] No more events, terminating") 241 break 242 continue 243 except Exception as e: 244 logger.error(f"[{self.execution_id}] Error consuming events: {e}") 245 break 246 247 logger.info(f"[{self.execution_id}] Event consumption completed") 248 249 async def _process_event(self, event: SequencedEvent) -> Optional[str]: 250 """イベントをストリーミング用に処理""" 251 streaming_event = event.to_streaming_event() 252 253 if self._handler: 254 try: 255 return await self._handler(streaming_event) 256 except Exception as e: 257 logger.error( 258 f"[{self.execution_id}] Error in event handler: {e}" 259 ) 260 return None 261 else: 262 # デフォルト: SSE形式で返す 263 import json 264 return f"data: {json.dumps(streaming_event.to_dict())}\n\n" 265 266 def mark_completed(self) -> None: 267 """ 268 手動で完了状態に設定 269 270 Note: 271 通常はCOMPLETION_SUCCESS/COMPLETION_FAILUREイベントで 272 自動的に完了状態になります。 273 このメソッドは特殊なケースでのみ使用してください。 274 """ 275 self._is_completed = True 276 self._completion_event.set() 277 logger.info(f"[{self.execution_id}] Manually marked as completed") 278 279 async def cleanup(self) -> None: 280 """ 281 リソースのクリーンアップ 282 283 キューに残っているイベントを破棄し、状態をリセットします。 284 285 Example: 286 >>> await emitter.cleanup() 287 """ 288 # キューをクリア 289 while not self._event_queue.empty(): 290 try: 291 self._event_queue.get_nowait() 292 except asyncio.QueueEmpty: 293 break 294 295 self._is_completed = True 296 self._completion_event.set() 297 logger.info(f"[{self.execution_id}] EventEmitter cleaned up")
イベント発行クライアント
非同期イベントキューを使用してイベントを発行・消費します。 シーケンス番号による順序保証を提供します。
Features: - 非同期イベント発行(emit_event) - シーケンス番号による順序保証 - 非同期イベント消費(consume_events) - カスタムイベントハンドラー対応
Example:
基本的な使い方
emitter = EventEmitter(execution_id="exec-123")
イベント発行
await emitter.emit_event( ... EventType.PHASE_START, ... "処理を開始します" ... )
イベント消費
async for chunk in emitter.consume_events(): ... print(chunk)
完了通知
await emitter.emit_event( ... EventType.COMPLETION_SUCCESS, ... "処理が完了しました" ... )
Note: - emit_eventとconsume_eventsは非同期で実行してください - 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)発行後、 consume_eventsは自動的に終了します
80 def __init__( 81 self, 82 execution_id: str, 83 handler: Optional[EventHandler] = None, 84 ): 85 """ 86 EventEmitterを初期化 87 88 Args: 89 execution_id: 実行ID(イベントの識別に使用) 90 handler: イベントハンドラー(オプション) 91 指定しない場合、デフォルトのJSON変換が使用されます 92 93 Example: 94 >>> emitter = EventEmitter(execution_id="exec-123") 95 96 >>> # カスタムハンドラー付き 97 >>> async def custom_handler(event: StreamingEvent) -> Optional[str]: 98 ... return f"data: {event.to_dict()}\\n\\n" 99 >>> emitter = EventEmitter(execution_id="exec-123", handler=custom_handler) 100 """ 101 self.execution_id = execution_id 102 self._handler = handler 103 self._event_queue: asyncio.Queue[SequencedEvent] = asyncio.Queue() 104 self._sequence_number = 0 105 self._is_completed = False 106 self._completion_event = asyncio.Event() 107 108 logger.debug(f"[{execution_id}] EventEmitter initialized")
EventEmitterを初期化
Args: execution_id: 実行ID(イベントの識別に使用) handler: イベントハンドラー(オプション) 指定しない場合、デフォルトのJSON変換が使用されます
Example:
emitter = EventEmitter(execution_id="exec-123")
>>> # カスタムハンドラー付き >>> async def custom_handler(event: StreamingEvent) -> Optional[str]: ... return f"data: {event.to_dict()}\n\n" >>> emitter = EventEmitter(execution_id="exec-123", handler=custom_handler)
115 @property 116 def sequence_number(self) -> int: 117 """現在のシーケンス番号を取得""" 118 return self._sequence_number
現在のシーケンス番号を取得
120 async def emit_event( 121 self, 122 event_type: EventType, 123 message: str, 124 metadata: Optional[Dict[str, Any]] = None, 125 sub_event_type: Optional[SubEventType] = None, 126 ) -> None: 127 """ 128 イベントを発行 129 130 Args: 131 event_type: イベントタイプ 132 message: イベントメッセージ 133 metadata: 追加メタデータ(オプション) 134 sub_event_type: サブイベントタイプ(オプション) 135 136 Example: 137 >>> await emitter.emit_event( 138 ... EventType.THOUGHT_MESSAGE, 139 ... "処理中です...", 140 ... metadata={"progress": 50} 141 ... ) 142 143 Note: 144 完了イベント発行後にイベントを発行しようとすると、 145 警告ログが出力され、イベントは無視されます。 146 """ 147 if self._is_completed: 148 logger.warning( 149 f"[{self.execution_id}] Event emission after completion: {event_type}" 150 ) 151 return 152 153 sequenced_event = SequencedEvent( 154 sequence=self._sequence_number, 155 event_type=event_type, 156 execution_id=self.execution_id, 157 message=message, 158 timestamp=time.time(), 159 metadata=metadata, 160 sub_event_type=sub_event_type, 161 ) 162 self._sequence_number += 1 163 164 try: 165 await self._event_queue.put(sequenced_event) 166 logger.info( 167 f"[{self.execution_id}] Event queued: {event_type.value} " 168 f"(seq={sequenced_event.sequence})" 169 ) 170 except Exception as e: 171 logger.error( 172 f"[{self.execution_id}] Failed to queue event {event_type}: {e}" 173 ) 174 raise
イベントを発行
Args: event_type: イベントタイプ message: イベントメッセージ metadata: 追加メタデータ(オプション) sub_event_type: サブイベントタイプ(オプション)
Example:
await emitter.emit_event( ... EventType.THOUGHT_MESSAGE, ... "処理中です...", ... metadata={"progress": 50} ... )
Note: 完了イベント発行後にイベントを発行しようとすると、 警告ログが出力され、イベントは無視されます。
176 async def emit( 177 self, 178 event_type: EventType, 179 message: str, 180 metadata: Optional[Dict[str, Any]] = None, 181 sub_event_type: Optional[SubEventType] = None, 182 ) -> None: 183 """emit_eventのエイリアス(短縮形)""" 184 await self.emit_event(event_type, message, metadata, sub_event_type)
emit_eventのエイリアス(短縮形)
186 async def consume_events( 187 self, 188 timeout: float = 0.1, 189 ) -> AsyncGenerator[str, None]: 190 """ 191 イベントを順序通りに消費してストリーミングチャンクを生成 192 193 Args: 194 timeout: イベント待機タイムアウト(秒) 195 196 Yields: 197 str: ストリーミングチャンク(ハンドラーの戻り値) 198 199 Example: 200 >>> async for chunk in emitter.consume_events(): 201 ... await response.write(chunk) 202 203 Note: 204 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)を受信し、 205 キューが空になった時点で自動的に終了します。 206 """ 207 logger.info(f"[{self.execution_id}] Starting event consumption") 208 209 while True: 210 try: 211 event = await asyncio.wait_for( 212 self._event_queue.get(), 213 timeout=timeout, 214 ) 215 216 logger.debug( 217 f"[{self.execution_id}] Processing event: {event.event_type.value} " 218 f"(seq={event.sequence})" 219 ) 220 221 # イベントをストリーミングチャンクに変換 222 chunk = await self._process_event(event) 223 if chunk: 224 yield chunk 225 226 # 完了イベントを処理したら完了状態に設定 227 if event.event_type in [ 228 EventType.COMPLETION_SUCCESS, 229 EventType.COMPLETION_FAILURE, 230 ]: 231 self._completion_event.set() 232 self._is_completed = True 233 logger.info( 234 f"[{self.execution_id}] Completion event processed" 235 ) 236 237 except asyncio.TimeoutError: 238 # 完了済みでキューが空なら終了 239 if self._completion_event.is_set() and self._event_queue.empty(): 240 logger.info(f"[{self.execution_id}] No more events, terminating") 241 break 242 continue 243 except Exception as e: 244 logger.error(f"[{self.execution_id}] Error consuming events: {e}") 245 break 246 247 logger.info(f"[{self.execution_id}] Event consumption completed")
イベントを順序通りに消費してストリーミングチャンクを生成
Args: timeout: イベント待機タイムアウト(秒)
Yields: str: ストリーミングチャンク(ハンドラーの戻り値)
Example:
async for chunk in emitter.consume_events(): ... await response.write(chunk)
Note: 完了イベント(COMPLETION_SUCCESS/COMPLETION_FAILURE)を受信し、 キューが空になった時点で自動的に終了します。
266 def mark_completed(self) -> None: 267 """ 268 手動で完了状態に設定 269 270 Note: 271 通常はCOMPLETION_SUCCESS/COMPLETION_FAILUREイベントで 272 自動的に完了状態になります。 273 このメソッドは特殊なケースでのみ使用してください。 274 """ 275 self._is_completed = True 276 self._completion_event.set() 277 logger.info(f"[{self.execution_id}] Manually marked as completed")
手動で完了状態に設定
Note: 通常はCOMPLETION_SUCCESS/COMPLETION_FAILUREイベントで 自動的に完了状態になります。 このメソッドは特殊なケースでのみ使用してください。
279 async def cleanup(self) -> None: 280 """ 281 リソースのクリーンアップ 282 283 キューに残っているイベントを破棄し、状態をリセットします。 284 285 Example: 286 >>> await emitter.cleanup() 287 """ 288 # キューをクリア 289 while not self._event_queue.empty(): 290 try: 291 self._event_queue.get_nowait() 292 except asyncio.QueueEmpty: 293 break 294 295 self._is_completed = True 296 self._completion_event.set() 297 logger.info(f"[{self.execution_id}] EventEmitter cleaned up")
リソースのクリーンアップ
キューに残っているイベントを破棄し、状態をリセットします。
Example:
await emitter.cleanup()
25class EventHandler(Protocol): 26 """イベントハンドラープロトコル 27 28 イベント受信時のコールバック関数の型定義 29 30 Example: 31 >>> async def my_handler(event: StreamingEvent) -> Optional[str]: 32 ... print(f"Received: {event.message}") 33 ... return f"data: {event.to_dict()}" 34 """ 35 36 async def __call__(self, event: StreamingEvent) -> Optional[str]: 37 """イベントを処理してストリーミングチャンクを返す""" 38 ...
イベントハンドラープロトコル
イベント受信時のコールバック関数の型定義
Example:
async def my_handler(event: StreamingEvent) -> Optional[str]: ... print(f"Received: {event.message}") ... return f"data: {event.to_dict()}"
1739def _no_init_or_replace_init(self, *args, **kwargs): 1740 cls = type(self) 1741 1742 if cls._is_protocol: 1743 raise TypeError('Protocols cannot be instantiated') 1744 1745 # Already using a custom `__init__`. No need to calculate correct 1746 # `__init__` to call. This can lead to RecursionError. See bpo-45121. 1747 if cls.__init__ is not _no_init_or_replace_init: 1748 return 1749 1750 # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`. 1751 # The first instantiation of the subclass will call `_no_init_or_replace_init` which 1752 # searches for a proper new `__init__` in the MRO. The new `__init__` 1753 # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent 1754 # instantiation of the protocol subclass will thus use the new 1755 # `__init__` and no longer call `_no_init_or_replace_init`. 1756 for base in cls.__mro__: 1757 init = base.__dict__.get('__init__', _no_init_or_replace_init) 1758 if init is not _no_init_or_replace_init: 1759 cls.__init__ = init 1760 break 1761 else: 1762 # should not happen 1763 cls.__init__ = object.__init__ 1764 1765 cls.__init__(self, *args, **kwargs)
300def create_sse_handler() -> EventHandler: 301 """ 302 SSE(Server-Sent Events)形式のイベントハンドラーを作成 303 304 Returns: 305 EventHandler: SSE形式でイベントを変換するハンドラー 306 307 Example: 308 >>> emitter = EventEmitter( 309 ... execution_id="exec-123", 310 ... handler=create_sse_handler() 311 ... ) 312 """ 313 import json 314 315 async def sse_handler(event: StreamingEvent) -> Optional[str]: 316 return f"data: {json.dumps(event.to_dict())}\n\n" 317 318 return sse_handler
SSE(Server-Sent Events)形式のイベントハンドラーを作成
Returns: EventHandler: SSE形式でイベントを変換するハンドラー
Example:
emitter = EventEmitter( ... execution_id="exec-123", ... handler=create_sse_handler() ... )
321def create_json_handler() -> EventHandler: 322 """ 323 JSON形式のイベントハンドラーを作成 324 325 Returns: 326 EventHandler: JSON形式でイベントを変換するハンドラー 327 328 Example: 329 >>> emitter = EventEmitter( 330 ... execution_id="exec-123", 331 ... handler=create_json_handler() 332 ... ) 333 """ 334 import json 335 336 async def json_handler(event: StreamingEvent) -> Optional[str]: 337 return json.dumps(event.to_dict()) 338 339 return json_handler
JSON形式のイベントハンドラーを作成
Returns: EventHandler: JSON形式でイベントを変換するハンドラー
Example:
emitter = EventEmitter( ... execution_id="exec-123", ... handler=create_json_handler() ... )
45class DatabaseEventHandler: 46 """ 47 データベースにイベントを保存するハンドラー 48 49 SDKのDataAccessを使用してexecution_messagesテーブルにイベントを保存。 50 マーケットプレイスUIとの連携に使用。 51 52 Features: 53 - StreamingEventをExecutionMessageに変換 54 - 汎用DataAccessによるDB保存 55 - テーブル名のカスタマイズ対応 56 57 Example: 58 >>> from agenticstar_platform.db import DataAccess, PostgreSQLConfig 59 >>> 60 >>> config = PostgreSQLConfig.from_env() 61 >>> async with DataAccess(config) as data_access: 62 ... handler = DatabaseEventHandler( 63 ... data_access=data_access, 64 ... user_id="user-123", 65 ... conversation_id="conv-456", 66 ... message_id="msg-789", 67 ... ) 68 ... await handler(event) 69 70 Note: 71 テーブルスキーマ(execution_messages): 72 - user_id: VARCHAR 73 - conversation_id: VARCHAR 74 - message_id: VARCHAR 75 - chunk_type: VARCHAR 76 - content_data: JSONB 77 - created_at: TIMESTAMP (自動設定) 78 """ 79 80 def __init__( 81 self, 82 data_access: Any, # DataAccess型だが循環インポート回避のためAny 83 user_id: str, 84 conversation_id: str, 85 message_id: str, 86 table_name: str = "execution_messages", 87 chunk_type_map: Optional[Dict[EventType, str]] = None, 88 ): 89 """ 90 DatabaseEventHandlerを初期化 91 92 Args: 93 data_access: SDKのDataAccessインスタンス 94 user_id: ユーザーID 95 conversation_id: 会話ID 96 message_id: メッセージID 97 table_name: 保存先テーブル名(デフォルト: execution_messages) 98 chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション) 99 100 Example: 101 >>> handler = DatabaseEventHandler( 102 ... data_access=data_access, 103 ... user_id="user-123", 104 ... conversation_id="conv-456", 105 ... message_id="msg-789", 106 ... ) 107 """ 108 self.data_access = data_access 109 self.user_id = user_id 110 self.conversation_id = conversation_id 111 self.message_id = message_id 112 self.table_name = table_name 113 114 # デフォルトのchunk_typeマッピング 115 self._chunk_type_map = chunk_type_map or { 116 EventType.PHASE_START: "content", 117 EventType.PROGRESS_UPDATE: "content", 118 EventType.THOUGHT_MESSAGE: "reasoning_content", 119 EventType.TOOL_START: "task_created", 120 EventType.TOOL_RESULT: "task_created", 121 EventType.FILE_CREATED: "task_created", 122 EventType.USER_INTERACTION_REQUIRED: "content", 123 EventType.COMPLETION_SUCCESS: "content", 124 EventType.COMPLETION_FAILURE: "content", 125 EventType.HITL_REQUIRED_BROWSER_VNC: "vnc_hitl_event", 126 EventType.HITL_COMPLETED: "vnc_hitl_event", 127 EventType.FINAL_RESULT: None, 128 } 129 130 logger.debug( 131 f"DatabaseEventHandler initialized: table={table_name}, " 132 f"user_id={user_id}, conversation_id={conversation_id}" 133 ) 134 135 async def __call__(self, event: StreamingEvent) -> Optional[str]: 136 """ 137 イベントをデータベースに保存 138 139 Args: 140 event: ストリーミングイベント 141 142 Returns: 143 None(DBハンドラーはストリーミングチャンクを返さない) 144 """ 145 try: 146 # chunk_typeを決定(Noneの場合はスキップ) 147 chunk_type = self._chunk_type_map.get(event.event_type, "content") 148 if chunk_type is None: 149 logger.debug(f"Event skipped (chunk_type=None): {event.event_type.value}") 150 return None 151 152 # ExecutionMessageを作成 153 exec_msg = ExecutionMessage.from_streaming_event( 154 event=event, 155 user_id=self.user_id, 156 conversation_id=self.conversation_id, 157 message_id=self.message_id, 158 chunk_type=chunk_type, 159 ) 160 161 # DataAccessでDB保存 162 result = await self.data_access.insert( 163 self.table_name, 164 exec_msg.to_dict(), 165 ) 166 167 if not result.get("success"): 168 logger.error( 169 f"Failed to save event to database: {result.get('error')}" 170 ) 171 else: 172 logger.debug( 173 f"Event saved to database: {event.event_type.value}" 174 ) 175 176 return None 177 178 except Exception as e: 179 logger.error(f"Error saving event to database: {e}") 180 return None
データベースにイベントを保存するハンドラー
SDKのDataAccessを使用してexecution_messagesテーブルにイベントを保存。 マーケットプレイスUIとの連携に使用。
Features: - StreamingEventをExecutionMessageに変換 - 汎用DataAccessによるDB保存 - テーブル名のカスタマイズ対応
Example:
from agenticstar_platform.db import DataAccess, PostgreSQLConfig
config = PostgreSQLConfig.from_env() async with DataAccess(config) as data_access: ... handler = DatabaseEventHandler( ... data_access=data_access, ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789", ... ) ... await handler(event)
Note: テーブルスキーマ(execution_messages): - user_id: VARCHAR - conversation_id: VARCHAR - message_id: VARCHAR - chunk_type: VARCHAR - content_data: JSONB - created_at: TIMESTAMP (自動設定)
80 def __init__( 81 self, 82 data_access: Any, # DataAccess型だが循環インポート回避のためAny 83 user_id: str, 84 conversation_id: str, 85 message_id: str, 86 table_name: str = "execution_messages", 87 chunk_type_map: Optional[Dict[EventType, str]] = None, 88 ): 89 """ 90 DatabaseEventHandlerを初期化 91 92 Args: 93 data_access: SDKのDataAccessインスタンス 94 user_id: ユーザーID 95 conversation_id: 会話ID 96 message_id: メッセージID 97 table_name: 保存先テーブル名(デフォルト: execution_messages) 98 chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション) 99 100 Example: 101 >>> handler = DatabaseEventHandler( 102 ... data_access=data_access, 103 ... user_id="user-123", 104 ... conversation_id="conv-456", 105 ... message_id="msg-789", 106 ... ) 107 """ 108 self.data_access = data_access 109 self.user_id = user_id 110 self.conversation_id = conversation_id 111 self.message_id = message_id 112 self.table_name = table_name 113 114 # デフォルトのchunk_typeマッピング 115 self._chunk_type_map = chunk_type_map or { 116 EventType.PHASE_START: "content", 117 EventType.PROGRESS_UPDATE: "content", 118 EventType.THOUGHT_MESSAGE: "reasoning_content", 119 EventType.TOOL_START: "task_created", 120 EventType.TOOL_RESULT: "task_created", 121 EventType.FILE_CREATED: "task_created", 122 EventType.USER_INTERACTION_REQUIRED: "content", 123 EventType.COMPLETION_SUCCESS: "content", 124 EventType.COMPLETION_FAILURE: "content", 125 EventType.HITL_REQUIRED_BROWSER_VNC: "vnc_hitl_event", 126 EventType.HITL_COMPLETED: "vnc_hitl_event", 127 EventType.FINAL_RESULT: None, 128 } 129 130 logger.debug( 131 f"DatabaseEventHandler initialized: table={table_name}, " 132 f"user_id={user_id}, conversation_id={conversation_id}" 133 )
DatabaseEventHandlerを初期化
Args: data_access: SDKのDataAccessインスタンス user_id: ユーザーID conversation_id: 会話ID message_id: メッセージID table_name: 保存先テーブル名(デフォルト: execution_messages) chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション)
Example:
handler = DatabaseEventHandler( ... data_access=data_access, ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789", ... )
183class WebhookEventHandler: 184 """ 185 Webhookにイベントを送信するハンドラー 186 187 HTTP POSTでイベントをWebhookエンドポイントに送信。 188 AgenticStar API等との連携に使用。 189 190 Features: 191 - 非同期HTTP POST送信 192 - カスタムヘッダー対応 193 - タイムアウト設定 194 - メッセージフォーマットのカスタマイズ 195 196 Example: 197 >>> handler = WebhookEventHandler( 198 ... webhook_url="http://your-api:3081/webhook/notify", 199 ... conversation_id="conv-456", 200 ... message_id="msg-789", 201 ... ) 202 >>> await handler(event) 203 204 Note: 205 送信されるJSON形式: 206 { 207 "type": "append_message", 208 "data": { 209 "conversationId": "...", 210 "messageId": "...", 211 "chunk_type": "content", 212 "content": {"content": "..."}, 213 "finish_reason": null 214 } 215 } 216 """ 217 218 def __init__( 219 self, 220 webhook_url: str, 221 conversation_id: str, 222 message_id: str, 223 headers: Optional[Dict[str, str]] = None, 224 timeout_seconds: int = 30, 225 message_type: str = "append_message", 226 token_provider: Optional[Callable[[], Optional[str]]] = None, 227 chunk_type_map: Optional[Dict[EventType, str]] = None, 228 ): 229 """ 230 WebhookEventHandlerを初期化 231 232 Args: 233 webhook_url: WebhookエンドポイントURL 234 conversation_id: 会話ID 235 message_id: メッセージID 236 headers: カスタムHTTPヘッダー(オプション) 237 timeout_seconds: リクエストタイムアウト秒数 238 message_type: メッセージタイプ(デフォルト: append_message) 239 token_provider: 認証トークン取得関数(オプション) 240 chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション) 241 242 Example: 243 >>> handler = WebhookEventHandler( 244 ... webhook_url="http://your-api:3081/webhook/notify", 245 ... conversation_id="conv-456", 246 ... message_id="msg-789", 247 ... headers={"X-Custom-Header": "value"}, 248 ... ) 249 """ 250 self.webhook_url = webhook_url 251 self.conversation_id = conversation_id 252 self.message_id = message_id 253 self.headers = headers or {} 254 self.timeout_seconds = timeout_seconds 255 self.message_type = message_type 256 self.token_provider = token_provider 257 258 # デフォルトのchunk_typeマッピング 259 self._chunk_type_map = chunk_type_map or { 260 EventType.PHASE_START: "content", 261 EventType.PROGRESS_UPDATE: "content", 262 EventType.THOUGHT_MESSAGE: "reasoning_content", 263 EventType.TOOL_START: "task_created", 264 EventType.TOOL_RESULT: "task_created", 265 EventType.FILE_CREATED: "task_created", 266 EventType.USER_INTERACTION_REQUIRED: "content", 267 EventType.COMPLETION_SUCCESS: "content", 268 EventType.COMPLETION_FAILURE: "content", 269 EventType.HITL_REQUIRED_BROWSER_VNC: "vnc_hitl_event", 270 EventType.HITL_COMPLETED: "vnc_hitl_event", 271 EventType.FINAL_RESULT: None, 272 } 273 274 # finish_reasonマッピング 275 self._finish_reason_map = { 276 EventType.COMPLETION_SUCCESS: "stop", 277 EventType.COMPLETION_FAILURE: "error", 278 } 279 280 logger.debug( 281 f"WebhookEventHandler initialized: url={webhook_url}, " 282 f"conversation_id={conversation_id}" 283 ) 284 285 async def __call__(self, event: StreamingEvent) -> Optional[str]: 286 """ 287 イベントをWebhookに送信 288 289 Args: 290 event: ストリーミングイベント 291 292 Returns: 293 None(Webhookハンドラーはストリーミングチャンクを返さない) 294 """ 295 try: 296 # aiohttpをここでimport(オプション依存のため) 297 try: 298 import aiohttp 299 except ImportError: 300 logger.error( 301 "aiohttp is required for WebhookEventHandler: " 302 "pip install aiohttp" 303 ) 304 return None 305 306 # chunk_typeを決定(Noneの場合はスキップ) 307 chunk_type = self._chunk_type_map.get(event.event_type, "content") 308 if chunk_type is None: 309 logger.debug(f"Webhook event skipped (chunk_type=None): {event.event_type.value}") 310 return None 311 312 # finish_reasonを決定 313 finish_reason = self._finish_reason_map.get(event.event_type) 314 315 # contentを構築 316 content = {"content": event.message} 317 if event.metadata: 318 content.update(event.metadata) 319 320 # Webhookペイロードを構築 321 payload = { 322 "type": self.message_type, 323 "data": { 324 "conversationId": self.conversation_id, 325 "messageId": self.message_id, 326 "chunk_type": chunk_type, 327 "content": content, 328 "finish_reason": finish_reason, 329 }, 330 } 331 332 # ヘッダーを設定 333 headers = {"Content-Type": "application/json"} 334 headers.update(self.headers) 335 336 # トークンプロバイダーがあれば認証ヘッダーを追加 337 if self.token_provider: 338 token = self.token_provider() 339 if token: 340 headers["Authorization"] = f"Bearer {token}" 341 342 # HTTP POSTを送信 343 timeout = aiohttp.ClientTimeout(total=self.timeout_seconds) 344 async with aiohttp.ClientSession(timeout=timeout) as session: 345 async with session.post( 346 self.webhook_url, 347 json=payload, 348 headers=headers, 349 ) as response: 350 # レスポンスを読み取り(リソースリークを防ぐ) 351 await response.text() 352 353 if response.status == 200: 354 logger.debug( 355 f"Webhook notification sent: {event.event_type.value}" 356 ) 357 else: 358 logger.warning( 359 f"Webhook returned status {response.status}" 360 ) 361 362 return None 363 364 except asyncio.TimeoutError: 365 logger.error("Webhook notification timeout") 366 return None 367 except Exception as e: 368 logger.error(f"Error sending webhook notification: {e}") 369 return None
Webhookにイベントを送信するハンドラー
HTTP POSTでイベントをWebhookエンドポイントに送信。 AgenticStar API等との連携に使用。
Features: - 非同期HTTP POST送信 - カスタムヘッダー対応 - タイムアウト設定 - メッセージフォーマットのカスタマイズ
Example:
handler = WebhookEventHandler( ... webhook_url="http://your-api:3081/webhook/notify", ... conversation_id="conv-456", ... message_id="msg-789", ... ) await handler(event)
Note: 送信されるJSON形式: { "type": "append_message", "data": { "conversationId": "...", "messageId": "...", "chunk_type": "content", "content": {"content": "..."}, "finish_reason": null } }
218 def __init__( 219 self, 220 webhook_url: str, 221 conversation_id: str, 222 message_id: str, 223 headers: Optional[Dict[str, str]] = None, 224 timeout_seconds: int = 30, 225 message_type: str = "append_message", 226 token_provider: Optional[Callable[[], Optional[str]]] = None, 227 chunk_type_map: Optional[Dict[EventType, str]] = None, 228 ): 229 """ 230 WebhookEventHandlerを初期化 231 232 Args: 233 webhook_url: WebhookエンドポイントURL 234 conversation_id: 会話ID 235 message_id: メッセージID 236 headers: カスタムHTTPヘッダー(オプション) 237 timeout_seconds: リクエストタイムアウト秒数 238 message_type: メッセージタイプ(デフォルト: append_message) 239 token_provider: 認証トークン取得関数(オプション) 240 chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション) 241 242 Example: 243 >>> handler = WebhookEventHandler( 244 ... webhook_url="http://your-api:3081/webhook/notify", 245 ... conversation_id="conv-456", 246 ... message_id="msg-789", 247 ... headers={"X-Custom-Header": "value"}, 248 ... ) 249 """ 250 self.webhook_url = webhook_url 251 self.conversation_id = conversation_id 252 self.message_id = message_id 253 self.headers = headers or {} 254 self.timeout_seconds = timeout_seconds 255 self.message_type = message_type 256 self.token_provider = token_provider 257 258 # デフォルトのchunk_typeマッピング 259 self._chunk_type_map = chunk_type_map or { 260 EventType.PHASE_START: "content", 261 EventType.PROGRESS_UPDATE: "content", 262 EventType.THOUGHT_MESSAGE: "reasoning_content", 263 EventType.TOOL_START: "task_created", 264 EventType.TOOL_RESULT: "task_created", 265 EventType.FILE_CREATED: "task_created", 266 EventType.USER_INTERACTION_REQUIRED: "content", 267 EventType.COMPLETION_SUCCESS: "content", 268 EventType.COMPLETION_FAILURE: "content", 269 EventType.HITL_REQUIRED_BROWSER_VNC: "vnc_hitl_event", 270 EventType.HITL_COMPLETED: "vnc_hitl_event", 271 EventType.FINAL_RESULT: None, 272 } 273 274 # finish_reasonマッピング 275 self._finish_reason_map = { 276 EventType.COMPLETION_SUCCESS: "stop", 277 EventType.COMPLETION_FAILURE: "error", 278 } 279 280 logger.debug( 281 f"WebhookEventHandler initialized: url={webhook_url}, " 282 f"conversation_id={conversation_id}" 283 )
WebhookEventHandlerを初期化
Args: webhook_url: WebhookエンドポイントURL conversation_id: 会話ID message_id: メッセージID headers: カスタムHTTPヘッダー(オプション) timeout_seconds: リクエストタイムアウト秒数 message_type: メッセージタイプ(デフォルト: append_message) token_provider: 認証トークン取得関数(オプション) chunk_type_map: EventType→chunk_typeのカスタムマッピング(オプション)
Example:
handler = WebhookEventHandler( ... webhook_url="http://your-api:3081/webhook/notify", ... conversation_id="conv-456", ... message_id="msg-789", ... headers={"X-Custom-Header": "value"}, ... )
372class CompositeEventHandler: 373 """ 374 複数のハンドラーを並列実行する複合ハンドラー 375 376 DatabaseEventHandlerとWebhookEventHandlerを同時実行し、 377 autonomousのPodWebhookSubscriberと同等の機能を提供。 378 379 Features: 380 - 複数ハンドラーの並列実行 381 - エラー時も他のハンドラーは継続 382 - 最初のSSEチャンクを返す(SSEハンドラーがある場合) 383 384 Example: 385 >>> db_handler = DatabaseEventHandler(...) 386 >>> webhook_handler = WebhookEventHandler(...) 387 >>> composite = CompositeEventHandler([db_handler, webhook_handler]) 388 >>> 389 >>> # EventEmitterに設定 390 >>> emitter = EventEmitter( 391 ... execution_id="exec-123", 392 ... handler=composite 393 ... ) 394 395 Note: 396 asyncio.gatherでreturn_exceptions=Trueを使用するため、 397 一部のハンドラーがエラーでも他は正常に実行される。 398 """ 399 400 def __init__( 401 self, 402 handlers: List[Callable[[StreamingEvent], Optional[str]]], 403 return_first_chunk: bool = True, 404 ): 405 """ 406 CompositeEventHandlerを初期化 407 408 Args: 409 handlers: イベントハンドラーのリスト 410 return_first_chunk: 最初のNon-Noneチャンクを返すか(デフォルト: True) 411 412 Example: 413 >>> composite = CompositeEventHandler([ 414 ... DatabaseEventHandler(...), 415 ... WebhookEventHandler(...), 416 ... create_sse_handler(), # SSEチャンクを返すハンドラー 417 ... ]) 418 """ 419 self.handlers = handlers 420 self.return_first_chunk = return_first_chunk 421 422 logger.debug( 423 f"CompositeEventHandler initialized with {len(handlers)} handlers" 424 ) 425 426 async def __call__(self, event: StreamingEvent) -> Optional[str]: 427 """ 428 全ハンドラーを並列実行 429 430 Args: 431 event: ストリーミングイベント 432 433 Returns: 434 最初のNon-Noneチャンク(return_first_chunk=Trueの場合) 435 またはNone 436 """ 437 try: 438 # 全ハンドラーを並列実行 439 tasks = [handler(event) for handler in self.handlers] 440 results = await asyncio.gather(*tasks, return_exceptions=True) 441 442 # エラーをログ出力 443 for i, result in enumerate(results): 444 if isinstance(result, Exception): 445 logger.error( 446 f"Handler {i} raised exception: {result}" 447 ) 448 449 # 最初のNon-Noneチャンクを返す 450 if self.return_first_chunk: 451 for result in results: 452 if result is not None and not isinstance(result, Exception): 453 return result 454 455 return None 456 457 except Exception as e: 458 logger.error(f"Error in CompositeEventHandler: {e}") 459 return None
複数のハンドラーを並列実行する複合ハンドラー
DatabaseEventHandlerとWebhookEventHandlerを同時実行し、 autonomousのPodWebhookSubscriberと同等の機能を提供。
Features: - 複数ハンドラーの並列実行 - エラー時も他のハンドラーは継続 - 最初のSSEチャンクを返す(SSEハンドラーがある場合)
Example:
db_handler = DatabaseEventHandler(...) webhook_handler = WebhookEventHandler(...) composite = CompositeEventHandler([db_handler, webhook_handler])
EventEmitterに設定
emitter = EventEmitter( ... execution_id="exec-123", ... handler=composite ... )
Note: asyncio.gatherでreturn_exceptions=Trueを使用するため、 一部のハンドラーがエラーでも他は正常に実行される。
400 def __init__( 401 self, 402 handlers: List[Callable[[StreamingEvent], Optional[str]]], 403 return_first_chunk: bool = True, 404 ): 405 """ 406 CompositeEventHandlerを初期化 407 408 Args: 409 handlers: イベントハンドラーのリスト 410 return_first_chunk: 最初のNon-Noneチャンクを返すか(デフォルト: True) 411 412 Example: 413 >>> composite = CompositeEventHandler([ 414 ... DatabaseEventHandler(...), 415 ... WebhookEventHandler(...), 416 ... create_sse_handler(), # SSEチャンクを返すハンドラー 417 ... ]) 418 """ 419 self.handlers = handlers 420 self.return_first_chunk = return_first_chunk 421 422 logger.debug( 423 f"CompositeEventHandler initialized with {len(handlers)} handlers" 424 )
CompositeEventHandlerを初期化
Args: handlers: イベントハンドラーのリスト return_first_chunk: 最初のNon-Noneチャンクを返すか(デフォルト: True)
Example:
composite = CompositeEventHandler([ ... DatabaseEventHandler(...), ... WebhookEventHandler(...), ... create_sse_handler(), # SSEチャンクを返すハンドラー ... ])
462def create_marketplace_handler( 463 data_access: Any, 464 webhook_url: str, 465 user_id: str, 466 conversation_id: str, 467 message_id: str, 468 token_provider: Optional[Callable[[], Optional[str]]] = None, 469) -> CompositeEventHandler: 470 """ 471 マーケットプレイス向けの標準ハンドラーを作成 472 473 DB保存 + Webhook送信の複合ハンドラーを簡単に作成するファクトリ関数。 474 475 Args: 476 data_access: SDKのDataAccessインスタンス 477 webhook_url: WebhookエンドポイントURL 478 user_id: ユーザーID 479 conversation_id: 会話ID 480 message_id: メッセージID 481 token_provider: 認証トークン取得関数(オプション) 482 483 Returns: 484 CompositeEventHandler: DB + Webhookの複合ハンドラー 485 486 Example: 487 >>> from agenticstar_platform.db import DataAccess, PostgreSQLConfig 488 >>> from agenticstar_platform.events import EventEmitter 489 >>> from agenticstar_platform.events.handlers import create_marketplace_handler 490 >>> 491 >>> config = PostgreSQLConfig.from_env() 492 >>> async with DataAccess(config) as da: 493 ... handler = create_marketplace_handler( 494 ... data_access=da, 495 ... webhook_url="http://your-api:3081/webhook/notify", 496 ... user_id="user-123", 497 ... conversation_id="conv-456", 498 ... message_id="msg-789", 499 ... ) 500 ... emitter = EventEmitter(execution_id="exec-001", handler=handler) 501 ... await emitter.emit(EventType.PHASE_START, "処理開始") 502 """ 503 db_handler = DatabaseEventHandler( 504 data_access=data_access, 505 user_id=user_id, 506 conversation_id=conversation_id, 507 message_id=message_id, 508 ) 509 510 webhook_handler = WebhookEventHandler( 511 webhook_url=webhook_url, 512 conversation_id=conversation_id, 513 message_id=message_id, 514 token_provider=token_provider, 515 ) 516 517 return CompositeEventHandler([db_handler, webhook_handler])
マーケットプレイス向けの標準ハンドラーを作成
DB保存 + Webhook送信の複合ハンドラーを簡単に作成するファクトリ関数。
Args: data_access: SDKのDataAccessインスタンス webhook_url: WebhookエンドポイントURL user_id: ユーザーID conversation_id: 会話ID message_id: メッセージID token_provider: 認証トークン取得関数(オプション)
Returns: CompositeEventHandler: DB + Webhookの複合ハンドラー
Example:
from agenticstar_platform.db import DataAccess, PostgreSQLConfig from agenticstar_platform.events import EventEmitter from agenticstar_platform.events.handlers import create_marketplace_handler
config = PostgreSQLConfig.from_env() async with DataAccess(config) as da: ... handler = create_marketplace_handler( ... data_access=da, ... webhook_url="http://your-api:3081/webhook/notify", ... user_id="user-123", ... conversation_id="conv-456", ... message_id="msg-789", ... ) ... emitter = EventEmitter(execution_id="exec-001", handler=handler) ... await emitter.emit(EventType.PHASE_START, "処理開始")
566class SemanticMemoryClient: 567 """ 568 汎用セマンティックメモリクライアント(Mem0ベース) 569 570 会話やコンテキストをセマンティック検索可能な形式で保存・検索します。 571 LLMとEmbedderを使用してメモリの抽出・ベクトル化を行います。 572 573 Features: 574 - マルチプロバイダーLLM/Embedder対応(Azure, OpenAI, Anthropic, Bedrock, Gemini, Groq) 575 - 基本的なメモリ追加・検索(add, search, get_all, delete, delete_all) 576 - カスタムプロンプト対応(ファクト抽出・メモリ更新プロンプトのカスタマイズ) 577 - ベクトルストア設定対応(Qdrant等) 578 579 Note: 580 - 組織階層(personal/role/org/global)は提供しない(利用側で実装) 581 - PIIマスキングは提供しない(利用側で実装) 582 583 Example: 584 >>> from agenticstar_platform import SemanticMemoryClient, SemanticMemoryConfig 585 >>> 586 >>> # 設定を作成 587 >>> config = SemanticMemoryConfig.from_toml("config.toml") 588 >>> 589 >>> # クライアント初期化 590 >>> client = SemanticMemoryClient(config) 591 >>> 592 >>> # メモリを追加 593 >>> client.add( 594 ... messages=[ 595 ... {"role": "user", "content": "私の名前は田中です"}, 596 ... {"role": "assistant", "content": "こんにちは、田中さん"} 597 ... ], 598 ... user_id="user-123" 599 ... ) 600 >>> 601 >>> # メモリを検索 602 >>> results = client.search(query="名前", user_id="user-123") 603 >>> print(results) 604 {'results': [{'memory': 'ユーザーの名前は田中', 'score': 0.95}]} 605 """ 606 607 def _get_api_key(self, config: LLMProviderConfig) -> str: 608 """APIキーを取得(SecretStr対応) 609 610 Args: 611 config: LLMProviderConfig 612 613 Returns: 614 APIキー文字列 615 """ 616 api_key = config.api_key 617 if hasattr(api_key, 'get_secret_value'): 618 return api_key.get_secret_value() 619 return api_key 620 621 def _validate_provider_config(self, config: LLMProviderConfig, context: str) -> None: 622 """プロバイダー固有の設定を検証 623 624 Args: 625 config: LLMProviderConfig 626 context: 検証コンテキスト(エラーメッセージ用) 627 628 Raises: 629 SemanticMemoryConfigError: 必須設定が不足している場合 630 """ 631 provider, _ = self._parse_model_string(config.model) 632 633 if provider == "azure": 634 if not config.base_url: 635 raise SemanticMemoryConfigError( 636 f"{context}: Azure provider requires 'base_url' (azure_endpoint)" 637 ) 638 if not config.api_version: 639 logger.warning( 640 f"{context}: Azure provider 'api_version' not set, using default" 641 ) 642 643 elif provider == "bedrock": 644 missing = [] 645 if not config.aws_access_key_id: 646 missing.append("aws_access_key_id") 647 if not config.aws_secret_access_key: 648 missing.append("aws_secret_access_key") 649 if not config.aws_region_name: 650 missing.append("aws_region_name") 651 if missing: 652 raise SemanticMemoryConfigError( 653 f"{context}: Bedrock provider requires: {', '.join(missing)}" 654 ) 655 656 def __init__(self, config: SemanticMemoryConfig): 657 """ 658 Initialize Semantic Memory Client 659 660 Args: 661 config: SemanticMemoryConfig instance 662 663 Raises: 664 SemanticMemoryConfigError: 設定が不正な場合 665 """ 666 if not config: 667 raise SemanticMemoryConfigError("SemanticMemoryConfig is required") 668 669 if not config.llm_config or not config.embedder_config: 670 raise SemanticMemoryConfigError( 671 "Both llm_config and embedder_config are required" 672 ) 673 674 # Validate provider-specific configurations 675 self._validate_provider_config(config.llm_config, "llm_config") 676 self._validate_provider_config(config.embedder_config, "embedder_config") 677 678 self._config = config 679 self._memory = None 680 self._enabled = False 681 682 try: 683 # Configure Mem0 telemetry before importing 684 _configure_mem0_telemetry() 685 686 from mem0 import Memory 687 688 # Build Mem0 format config 689 mem0_config = { 690 'llm': self._convert_llm_to_mem0(config.llm_config), 691 'embedder': self._convert_embedder_to_mem0(config.embedder_config), 692 } 693 694 if config.vector_store: 695 mem0_config['vector_store'] = config.vector_store 696 697 if config.rerank: 698 mem0_config['rerank'] = config.rerank 699 700 if config.custom_fact_extraction_prompt: 701 mem0_config['custom_fact_extraction_prompt'] = config.custom_fact_extraction_prompt 702 703 if config.custom_update_memory_prompt: 704 mem0_config['custom_update_memory_prompt'] = config.custom_update_memory_prompt 705 706 self._memory = Memory.from_config(mem0_config) 707 self._enabled = True 708 709 logger.info( 710 f"SemanticMemoryClient initialized - " 711 f"llm={config.llm_config.model}, " 712 f"embedder={config.embedder_config.model}" 713 ) 714 715 except ImportError as e: 716 logger.error(f"mem0 library not installed: {e}") 717 raise SemanticMemoryConfigError( 718 "mem0 library is required: pip install mem0ai" 719 ) from e 720 except Exception as e: 721 logger.error(f"Failed to initialize SemanticMemoryClient: {e}") 722 raise SemanticMemoryConfigError( 723 f"Failed to initialize Mem0: {e}" 724 ) from e 725 726 @property 727 def enabled(self) -> bool: 728 """クライアントが有効かどうか""" 729 return self._enabled 730 731 def _normalize_provider(self, provider: str) -> str: 732 """プロバイダー名を正規化(エイリアス対応)""" 733 provider_aliases = { 734 # Azure 735 "azure": "azure", 736 "azure_openai": "azure", 737 "azure-openai": "azure", 738 # Bedrock 739 "bedrock": "bedrock", 740 "aws_bedrock": "bedrock", 741 "aws-bedrock": "bedrock", 742 # Google 743 "gemini": "gemini", 744 "google": "gemini", 745 "google-genai": "gemini", 746 "google_genai": "gemini", 747 } 748 return provider_aliases.get(provider.lower(), provider.lower()) 749 750 def _parse_model_string(self, model: str) -> tuple[str, str]: 751 """モデル文字列をプロバイダーとモデル名に分離 752 753 Args: 754 model: "provider/model-name" or "model-name" format 755 756 Returns: 757 (provider, model_name) tuple 758 """ 759 if "/" in model: 760 provider = model.split("/")[0] 761 model_name = model.split("/", 1)[1] 762 else: 763 provider = "openai" 764 model_name = model 765 766 return self._normalize_provider(provider), model_name 767 768 def _convert_llm_to_mem0(self, llm_config: LLMProviderConfig) -> Dict[str, Any]: 769 """LLMProviderConfigからMem0形式のLLM設定に変換""" 770 provider, model_name = self._parse_model_string(llm_config.model) 771 api_key = self._get_api_key(llm_config) 772 773 if provider == "azure": 774 return { 775 'provider': 'azure_openai', 776 'config': { 777 'azure_kwargs': { 778 'api_key': api_key, 779 'azure_endpoint': llm_config.base_url, 780 'azure_deployment': model_name, 781 'api_version': llm_config.api_version or "2024-02-15-preview" 782 } 783 } 784 } 785 elif provider == "openai": 786 return { 787 'provider': 'openai', 788 'config': { 789 'model': model_name, 790 'api_key': api_key, 791 } 792 } 793 elif provider == "anthropic": 794 return { 795 'provider': 'anthropic', 796 'config': { 797 'model': model_name, 798 'api_key': api_key, 799 } 800 } 801 elif provider == "bedrock": 802 return { 803 'provider': 'aws_bedrock', 804 'config': { 805 'model': model_name, 806 'aws_access_key_id': llm_config.aws_access_key_id, 807 'aws_secret_access_key': llm_config.aws_secret_access_key, 808 'aws_region_name': llm_config.aws_region_name, 809 } 810 } 811 elif provider == "gemini": 812 return { 813 'provider': 'gemini', 814 'config': { 815 'model': model_name, 816 'api_key': api_key, 817 } 818 } 819 elif provider == "groq": 820 return { 821 'provider': 'groq', 822 'config': { 823 'model': model_name, 824 'api_key': api_key, 825 } 826 } 827 else: 828 # Default to provider name with OpenAI-like config 829 return { 830 'provider': provider, 831 'config': { 832 'model': model_name, 833 'api_key': api_key, 834 } 835 } 836 837 def _convert_embedder_to_mem0(self, embedder_config: LLMProviderConfig) -> Dict[str, Any]: 838 """LLMProviderConfigからMem0形式のEmbedder設定に変換""" 839 provider, model_name = self._parse_model_string(embedder_config.model) 840 api_key = self._get_api_key(embedder_config) 841 842 if provider == "azure": 843 return { 844 'provider': 'azure_openai', 845 'config': { 846 'model': model_name, 847 'azure_kwargs': { 848 'api_key': api_key, 849 'azure_endpoint': embedder_config.base_url, 850 'azure_deployment': model_name, 851 'api_version': embedder_config.api_version or "2024-02-15-preview" 852 } 853 } 854 } 855 elif provider == "openai": 856 return { 857 'provider': 'openai', 858 'config': { 859 'model': model_name, 860 'api_key': api_key, 861 } 862 } 863 elif provider == "bedrock": 864 return { 865 'provider': 'aws_bedrock', 866 'config': { 867 'model': model_name, 868 'aws_access_key_id': embedder_config.aws_access_key_id, 869 'aws_secret_access_key': embedder_config.aws_secret_access_key, 870 'aws_region_name': embedder_config.aws_region_name, 871 } 872 } 873 elif provider == "gemini": 874 return { 875 'provider': 'gemini', 876 'config': { 877 'model': model_name, 878 'api_key': api_key, 879 } 880 } 881 else: 882 # Default to provider name with OpenAI-like config 883 return { 884 'provider': provider, 885 'config': { 886 'model': model_name, 887 'api_key': api_key, 888 } 889 } 890 891 def add( 892 self, 893 messages: List[Dict[str, str]], 894 user_id: str, 895 metadata: Optional[Dict[str, Any]] = None, 896 ) -> Dict[str, Any]: 897 """ 898 メモリを追加 899 900 会話メッセージをセマンティックメモリに保存します。 901 保存されたメモリは後でsearch()メソッドで検索できます。 902 903 Args: 904 messages: メッセージリスト。各要素は {"role": "user"|"assistant", "content": "..."} 形式 905 user_id: ユーザーID(必須)。ユーザーごとにメモリが分離されます 906 metadata: 追加メタデータ(オプション)。検索時のフィルタリングに使用可能 907 908 Returns: 909 Dict[str, Any]: 追加結果。成功時は {"results": [...]} 形式 910 911 Raises: 912 SemanticMemoryError: メモリ追加に失敗した場合 913 914 Example: 915 >>> # 会話を保存 916 >>> result = memory_client.add( 917 ... messages=[ 918 ... {"role": "user", "content": "東京の天気は?"}, 919 ... {"role": "assistant", "content": "東京は晴れです。"} 920 ... ], 921 ... user_id="user-123", 922 ... metadata={"topic": "weather"} 923 ... ) 924 >>> print(result) 925 {'results': [...]} 926 """ 927 if not self._enabled: 928 return {"error": "SemanticMemoryClient is not enabled"} 929 930 try: 931 result = self._memory.add( 932 messages=messages, 933 user_id=user_id, 934 metadata=metadata or {}, 935 ) 936 logger.debug(f"Memory added for user_id={user_id}") 937 return result 938 939 except Exception as e: 940 logger.error(f"Failed to add memory: {e}") 941 raise SemanticMemoryError(f"Failed to add memory: {e}") from e 942 943 def search( 944 self, 945 query: str, 946 user_id: str, 947 limit: int = 10, 948 ) -> Dict[str, Any]: 949 """ 950 メモリを検索 951 952 保存されたメモリからセマンティック検索を行います。 953 クエリに意味的に関連するメモリを関連度順に返します。 954 955 Args: 956 query: 検索クエリ。自然言語で記述。例: "天気に関する会話" 957 user_id: ユーザーID。add()で指定したuser_idと同じ値を使用 958 limit: 結果数上限(デフォルト: 10) 959 960 Returns: 961 Dict[str, Any]: 検索結果。{"results": [{"id": "...", "memory": "...", "score": 0.9}, ...]} 962 963 Raises: 964 SemanticMemoryError: 検索に失敗した場合 965 966 Example: 967 >>> # メモリを検索 968 >>> results = memory_client.search( 969 ... query="天気について教えて", 970 ... user_id="user-123", 971 ... limit=5 972 ... ) 973 >>> for item in results.get("results", []): 974 ... print(f"Score: {item['score']:.2f} - {item['memory']}") 975 """ 976 if not self._enabled: 977 return {"error": "SemanticMemoryClient is not enabled", "results": []} 978 979 try: 980 result = self._memory.search( 981 query=query, 982 user_id=user_id, 983 limit=limit, 984 ) 985 logger.debug( 986 f"Memory searched for user_id={user_id}, " 987 f"results={len(result.get('results', []))}" 988 ) 989 return result 990 991 except Exception as e: 992 logger.error(f"Failed to search memory: {e}") 993 raise SemanticMemoryError(f"Failed to search memory: {e}") from e 994 995 def get_all(self, user_id: str) -> Dict[str, Any]: 996 """ 997 ユーザーの全メモリを取得 998 999 指定されたユーザーに紐づく全てのメモリを取得します。 1000 検索ではなく、保存された全てのメモリを一覧表示する場合に使用します。 1001 1002 Args: 1003 user_id: ユーザーID 1004 1005 Returns: 1006 Dict[str, Any]: メモリリスト。{"results": [{"id": "...", "memory": "...", "created_at": "..."}, ...]} 1007 1008 Raises: 1009 SemanticMemoryError: 取得に失敗した場合 1010 1011 Example: 1012 >>> # ユーザーの全メモリを取得 1013 >>> all_memories = memory_client.get_all(user_id="user-123") 1014 >>> print(f"Total memories: {len(all_memories.get('results', []))}") 1015 >>> for mem in all_memories.get("results", []): 1016 ... print(f"[{mem['id']}] {mem['memory']}") 1017 """ 1018 if not self._enabled: 1019 return {"error": "SemanticMemoryClient is not enabled", "results": []} 1020 1021 try: 1022 result = self._memory.get_all(user_id=user_id) 1023 logger.debug( 1024 f"All memories retrieved for user_id={user_id}, " 1025 f"count={len(result.get('results', []))}" 1026 ) 1027 return result 1028 1029 except Exception as e: 1030 logger.error(f"Failed to get all memories: {e}") 1031 raise SemanticMemoryError(f"Failed to get all memories: {e}") from e 1032 1033 def delete(self, memory_id: str) -> Dict[str, Any]: 1034 """ 1035 メモリを削除 1036 1037 指定されたIDのメモリを削除します。 1038 メモリIDはget_all()やsearch()で取得できます。 1039 1040 Args: 1041 memory_id: 削除するメモリのID 1042 1043 Returns: 1044 Dict[str, Any]: 削除結果。{"message": "Memory deleted successfully"} 1045 1046 Raises: 1047 SemanticMemoryError: 削除に失敗した場合 1048 1049 Example: 1050 >>> # 特定のメモリを削除 1051 >>> result = memory_client.delete(memory_id="mem-abc123") 1052 >>> print(result) 1053 {'message': 'Memory deleted successfully'} 1054 """ 1055 if not self._enabled: 1056 return {"error": "SemanticMemoryClient is not enabled"} 1057 1058 try: 1059 result = self._memory.delete(memory_id=memory_id) 1060 logger.debug(f"Memory deleted: {memory_id}") 1061 return result 1062 1063 except Exception as e: 1064 logger.error(f"Failed to delete memory: {e}") 1065 raise SemanticMemoryError(f"Failed to delete memory: {e}") from e 1066 1067 def delete_all(self, user_id: str) -> Dict[str, Any]: 1068 """ 1069 ユーザーの全メモリを削除 1070 1071 指定されたユーザーに紐づく全てのメモリを削除します。 1072 この操作は取り消せません。注意して使用してください。 1073 1074 Args: 1075 user_id: ユーザーID 1076 1077 Returns: 1078 Dict[str, Any]: 削除結果。{"message": "All memories deleted successfully"} 1079 1080 Raises: 1081 SemanticMemoryError: 削除に失敗した場合 1082 1083 Example: 1084 >>> # ユーザーの全メモリを削除 1085 >>> result = memory_client.delete_all(user_id="user-123") 1086 >>> print(result) 1087 {'message': 'All memories deleted successfully'} 1088 1089 Warning: 1090 この操作は取り消せません。全てのメモリが永久に削除されます。 1091 """ 1092 if not self._enabled: 1093 return {"error": "SemanticMemoryClient is not enabled"} 1094 1095 try: 1096 result = self._memory.delete_all(user_id=user_id) 1097 logger.debug(f"All memories deleted for user_id={user_id}") 1098 return result 1099 1100 except Exception as e: 1101 logger.error(f"Failed to delete all memories: {e}") 1102 raise SemanticMemoryError(f"Failed to delete all memories: {e}") from e 1103 1104 async def cleanup(self) -> None: 1105 """ 1106 クライアントリソースのクリーンアップ 1107 1108 非同期コンテキストマネージャーの終了時や、 1109 明示的にリソースを解放する場合に呼び出します。 1110 1111 Example: 1112 >>> # 明示的なクリーンアップ 1113 >>> await memory_client.cleanup() 1114 1115 >>> # コンテキストマネージャー使用時は自動で呼ばれる 1116 >>> async with memory_client: 1117 ... await memory_client.add(...) 1118 ... # ここで自動的にcleanup()が呼ばれる 1119 """ 1120 logger.debug("SemanticMemoryClient cleanup called") 1121 # No specific cleanup needed for mem0 at this time
汎用セマンティックメモリクライアント(Mem0ベース)
会話やコンテキストをセマンティック検索可能な形式で保存・検索します。 LLMとEmbedderを使用してメモリの抽出・ベクトル化を行います。
Features: - マルチプロバイダーLLM/Embedder対応(Azure, OpenAI, Anthropic, Bedrock, Gemini, Groq) - 基本的なメモリ追加・検索(add, search, get_all, delete, delete_all) - カスタムプロンプト対応(ファクト抽出・メモリ更新プロンプトのカスタマイズ) - ベクトルストア設定対応(Qdrant等)
Note: - 組織階層(personal/role/org/global)は提供しない(利用側で実装) - PIIマスキングは提供しない(利用側で実装)
Example:
from agenticstar_platform import SemanticMemoryClient, SemanticMemoryConfig
設定を作成
config = SemanticMemoryConfig.from_toml("config.toml")
クライアント初期化
client = SemanticMemoryClient(config)
メモリを追加
client.add( ... messages=[ ... {"role": "user", "content": "私の名前は田中です"}, ... {"role": "assistant", "content": "こんにちは、田中さん"} ... ], ... user_id="user-123" ... )
メモリを検索
results = client.search(query="名前", user_id="user-123") print(results) {'results': [{'memory': 'ユーザーの名前は田中', 'score': 0.95}]}
656 def __init__(self, config: SemanticMemoryConfig): 657 """ 658 Initialize Semantic Memory Client 659 660 Args: 661 config: SemanticMemoryConfig instance 662 663 Raises: 664 SemanticMemoryConfigError: 設定が不正な場合 665 """ 666 if not config: 667 raise SemanticMemoryConfigError("SemanticMemoryConfig is required") 668 669 if not config.llm_config or not config.embedder_config: 670 raise SemanticMemoryConfigError( 671 "Both llm_config and embedder_config are required" 672 ) 673 674 # Validate provider-specific configurations 675 self._validate_provider_config(config.llm_config, "llm_config") 676 self._validate_provider_config(config.embedder_config, "embedder_config") 677 678 self._config = config 679 self._memory = None 680 self._enabled = False 681 682 try: 683 # Configure Mem0 telemetry before importing 684 _configure_mem0_telemetry() 685 686 from mem0 import Memory 687 688 # Build Mem0 format config 689 mem0_config = { 690 'llm': self._convert_llm_to_mem0(config.llm_config), 691 'embedder': self._convert_embedder_to_mem0(config.embedder_config), 692 } 693 694 if config.vector_store: 695 mem0_config['vector_store'] = config.vector_store 696 697 if config.rerank: 698 mem0_config['rerank'] = config.rerank 699 700 if config.custom_fact_extraction_prompt: 701 mem0_config['custom_fact_extraction_prompt'] = config.custom_fact_extraction_prompt 702 703 if config.custom_update_memory_prompt: 704 mem0_config['custom_update_memory_prompt'] = config.custom_update_memory_prompt 705 706 self._memory = Memory.from_config(mem0_config) 707 self._enabled = True 708 709 logger.info( 710 f"SemanticMemoryClient initialized - " 711 f"llm={config.llm_config.model}, " 712 f"embedder={config.embedder_config.model}" 713 ) 714 715 except ImportError as e: 716 logger.error(f"mem0 library not installed: {e}") 717 raise SemanticMemoryConfigError( 718 "mem0 library is required: pip install mem0ai" 719 ) from e 720 except Exception as e: 721 logger.error(f"Failed to initialize SemanticMemoryClient: {e}") 722 raise SemanticMemoryConfigError( 723 f"Failed to initialize Mem0: {e}" 724 ) from e
Initialize Semantic Memory Client
Args: config: SemanticMemoryConfig instance
Raises: SemanticMemoryConfigError: 設定が不正な場合
891 def add( 892 self, 893 messages: List[Dict[str, str]], 894 user_id: str, 895 metadata: Optional[Dict[str, Any]] = None, 896 ) -> Dict[str, Any]: 897 """ 898 メモリを追加 899 900 会話メッセージをセマンティックメモリに保存します。 901 保存されたメモリは後でsearch()メソッドで検索できます。 902 903 Args: 904 messages: メッセージリスト。各要素は {"role": "user"|"assistant", "content": "..."} 形式 905 user_id: ユーザーID(必須)。ユーザーごとにメモリが分離されます 906 metadata: 追加メタデータ(オプション)。検索時のフィルタリングに使用可能 907 908 Returns: 909 Dict[str, Any]: 追加結果。成功時は {"results": [...]} 形式 910 911 Raises: 912 SemanticMemoryError: メモリ追加に失敗した場合 913 914 Example: 915 >>> # 会話を保存 916 >>> result = memory_client.add( 917 ... messages=[ 918 ... {"role": "user", "content": "東京の天気は?"}, 919 ... {"role": "assistant", "content": "東京は晴れです。"} 920 ... ], 921 ... user_id="user-123", 922 ... metadata={"topic": "weather"} 923 ... ) 924 >>> print(result) 925 {'results': [...]} 926 """ 927 if not self._enabled: 928 return {"error": "SemanticMemoryClient is not enabled"} 929 930 try: 931 result = self._memory.add( 932 messages=messages, 933 user_id=user_id, 934 metadata=metadata or {}, 935 ) 936 logger.debug(f"Memory added for user_id={user_id}") 937 return result 938 939 except Exception as e: 940 logger.error(f"Failed to add memory: {e}") 941 raise SemanticMemoryError(f"Failed to add memory: {e}") from e
メモリを追加
会話メッセージをセマンティックメモリに保存します。 保存されたメモリは後でsearch()メソッドで検索できます。
Args: messages: メッセージリスト。各要素は {"role": "user"|"assistant", "content": "..."} 形式 user_id: ユーザーID(必須)。ユーザーごとにメモリが分離されます metadata: 追加メタデータ(オプション)。検索時のフィルタリングに使用可能
Returns: Dict[str, Any]: 追加結果。成功時は {"results": [...]} 形式
Raises: SemanticMemoryError: メモリ追加に失敗した場合
Example:
会話を保存
result = memory_client.add( ... messages=[ ... {"role": "user", "content": "東京の天気は?"}, ... {"role": "assistant", "content": "東京は晴れです。"} ... ], ... user_id="user-123", ... metadata={"topic": "weather"} ... ) print(result) {'results': [...]}
943 def search( 944 self, 945 query: str, 946 user_id: str, 947 limit: int = 10, 948 ) -> Dict[str, Any]: 949 """ 950 メモリを検索 951 952 保存されたメモリからセマンティック検索を行います。 953 クエリに意味的に関連するメモリを関連度順に返します。 954 955 Args: 956 query: 検索クエリ。自然言語で記述。例: "天気に関する会話" 957 user_id: ユーザーID。add()で指定したuser_idと同じ値を使用 958 limit: 結果数上限(デフォルト: 10) 959 960 Returns: 961 Dict[str, Any]: 検索結果。{"results": [{"id": "...", "memory": "...", "score": 0.9}, ...]} 962 963 Raises: 964 SemanticMemoryError: 検索に失敗した場合 965 966 Example: 967 >>> # メモリを検索 968 >>> results = memory_client.search( 969 ... query="天気について教えて", 970 ... user_id="user-123", 971 ... limit=5 972 ... ) 973 >>> for item in results.get("results", []): 974 ... print(f"Score: {item['score']:.2f} - {item['memory']}") 975 """ 976 if not self._enabled: 977 return {"error": "SemanticMemoryClient is not enabled", "results": []} 978 979 try: 980 result = self._memory.search( 981 query=query, 982 user_id=user_id, 983 limit=limit, 984 ) 985 logger.debug( 986 f"Memory searched for user_id={user_id}, " 987 f"results={len(result.get('results', []))}" 988 ) 989 return result 990 991 except Exception as e: 992 logger.error(f"Failed to search memory: {e}") 993 raise SemanticMemoryError(f"Failed to search memory: {e}") from e
メモリを検索
保存されたメモリからセマンティック検索を行います。 クエリに意味的に関連するメモリを関連度順に返します。
Args: query: 検索クエリ。自然言語で記述。例: "天気に関する会話" user_id: ユーザーID。add()で指定したuser_idと同じ値を使用 limit: 結果数上限(デフォルト: 10)
Returns: Dict[str, Any]: 検索結果。{"results": [{"id": "...", "memory": "...", "score": 0.9}, ...]}
Raises: SemanticMemoryError: 検索に失敗した場合
Example:
メモリを検索
results = memory_client.search( ... query="天気について教えて", ... user_id="user-123", ... limit=5 ... ) for item in results.get("results", []): ... print(f"Score: {item['score']:.2f} - {item['memory']}")
995 def get_all(self, user_id: str) -> Dict[str, Any]: 996 """ 997 ユーザーの全メモリを取得 998 999 指定されたユーザーに紐づく全てのメモリを取得します。 1000 検索ではなく、保存された全てのメモリを一覧表示する場合に使用します。 1001 1002 Args: 1003 user_id: ユーザーID 1004 1005 Returns: 1006 Dict[str, Any]: メモリリスト。{"results": [{"id": "...", "memory": "...", "created_at": "..."}, ...]} 1007 1008 Raises: 1009 SemanticMemoryError: 取得に失敗した場合 1010 1011 Example: 1012 >>> # ユーザーの全メモリを取得 1013 >>> all_memories = memory_client.get_all(user_id="user-123") 1014 >>> print(f"Total memories: {len(all_memories.get('results', []))}") 1015 >>> for mem in all_memories.get("results", []): 1016 ... print(f"[{mem['id']}] {mem['memory']}") 1017 """ 1018 if not self._enabled: 1019 return {"error": "SemanticMemoryClient is not enabled", "results": []} 1020 1021 try: 1022 result = self._memory.get_all(user_id=user_id) 1023 logger.debug( 1024 f"All memories retrieved for user_id={user_id}, " 1025 f"count={len(result.get('results', []))}" 1026 ) 1027 return result 1028 1029 except Exception as e: 1030 logger.error(f"Failed to get all memories: {e}") 1031 raise SemanticMemoryError(f"Failed to get all memories: {e}") from e
ユーザーの全メモリを取得
指定されたユーザーに紐づく全てのメモリを取得します。 検索ではなく、保存された全てのメモリを一覧表示する場合に使用します。
Args: user_id: ユーザーID
Returns: Dict[str, Any]: メモリリスト。{"results": [{"id": "...", "memory": "...", "created_at": "..."}, ...]}
Raises: SemanticMemoryError: 取得に失敗した場合
Example:
ユーザーの全メモリを取得
all_memories = memory_client.get_all(user_id="user-123") print(f"Total memories: {len(all_memories.get('results', []))}") for mem in all_memories.get("results", []): ... print(f"[{mem['id']}] {mem['memory']}")
1033 def delete(self, memory_id: str) -> Dict[str, Any]: 1034 """ 1035 メモリを削除 1036 1037 指定されたIDのメモリを削除します。 1038 メモリIDはget_all()やsearch()で取得できます。 1039 1040 Args: 1041 memory_id: 削除するメモリのID 1042 1043 Returns: 1044 Dict[str, Any]: 削除結果。{"message": "Memory deleted successfully"} 1045 1046 Raises: 1047 SemanticMemoryError: 削除に失敗した場合 1048 1049 Example: 1050 >>> # 特定のメモリを削除 1051 >>> result = memory_client.delete(memory_id="mem-abc123") 1052 >>> print(result) 1053 {'message': 'Memory deleted successfully'} 1054 """ 1055 if not self._enabled: 1056 return {"error": "SemanticMemoryClient is not enabled"} 1057 1058 try: 1059 result = self._memory.delete(memory_id=memory_id) 1060 logger.debug(f"Memory deleted: {memory_id}") 1061 return result 1062 1063 except Exception as e: 1064 logger.error(f"Failed to delete memory: {e}") 1065 raise SemanticMemoryError(f"Failed to delete memory: {e}") from e
メモリを削除
指定されたIDのメモリを削除します。 メモリIDはget_all()やsearch()で取得できます。
Args: memory_id: 削除するメモリのID
Returns: Dict[str, Any]: 削除結果。{"message": "Memory deleted successfully"}
Raises: SemanticMemoryError: 削除に失敗した場合
Example:
特定のメモリを削除
result = memory_client.delete(memory_id="mem-abc123") print(result) {'message': 'Memory deleted successfully'}
1067 def delete_all(self, user_id: str) -> Dict[str, Any]: 1068 """ 1069 ユーザーの全メモリを削除 1070 1071 指定されたユーザーに紐づく全てのメモリを削除します。 1072 この操作は取り消せません。注意して使用してください。 1073 1074 Args: 1075 user_id: ユーザーID 1076 1077 Returns: 1078 Dict[str, Any]: 削除結果。{"message": "All memories deleted successfully"} 1079 1080 Raises: 1081 SemanticMemoryError: 削除に失敗した場合 1082 1083 Example: 1084 >>> # ユーザーの全メモリを削除 1085 >>> result = memory_client.delete_all(user_id="user-123") 1086 >>> print(result) 1087 {'message': 'All memories deleted successfully'} 1088 1089 Warning: 1090 この操作は取り消せません。全てのメモリが永久に削除されます。 1091 """ 1092 if not self._enabled: 1093 return {"error": "SemanticMemoryClient is not enabled"} 1094 1095 try: 1096 result = self._memory.delete_all(user_id=user_id) 1097 logger.debug(f"All memories deleted for user_id={user_id}") 1098 return result 1099 1100 except Exception as e: 1101 logger.error(f"Failed to delete all memories: {e}") 1102 raise SemanticMemoryError(f"Failed to delete all memories: {e}") from e
ユーザーの全メモリを削除
指定されたユーザーに紐づく全てのメモリを削除します。 この操作は取り消せません。注意して使用してください。
Args: user_id: ユーザーID
Returns: Dict[str, Any]: 削除結果。{"message": "All memories deleted successfully"}
Raises: SemanticMemoryError: 削除に失敗した場合
Example:
ユーザーの全メモリを削除
result = memory_client.delete_all(user_id="user-123") print(result) {'message': 'All memories deleted successfully'}
Warning: この操作は取り消せません。全てのメモリが永久に削除されます。
1104 async def cleanup(self) -> None: 1105 """ 1106 クライアントリソースのクリーンアップ 1107 1108 非同期コンテキストマネージャーの終了時や、 1109 明示的にリソースを解放する場合に呼び出します。 1110 1111 Example: 1112 >>> # 明示的なクリーンアップ 1113 >>> await memory_client.cleanup() 1114 1115 >>> # コンテキストマネージャー使用時は自動で呼ばれる 1116 >>> async with memory_client: 1117 ... await memory_client.add(...) 1118 ... # ここで自動的にcleanup()が呼ばれる 1119 """ 1120 logger.debug("SemanticMemoryClient cleanup called") 1121 # No specific cleanup needed for mem0 at this time
クライアントリソースのクリーンアップ
非同期コンテキストマネージャーの終了時や、 明示的にリソースを解放する場合に呼び出します。
Example:
明示的なクリーンアップ
await memory_client.cleanup()
>>> # コンテキストマネージャー使用時は自動で呼ばれる >>> async with memory_client: ... await memory_client.add(...) ... # ここで自動的にcleanup()が呼ばれる
462@dataclass 463class SemanticMemoryConfig: 464 """Semantic Memory設定 465 466 Example: 467 >>> # 辞書から作成 468 >>> config = SemanticMemoryConfig.from_dict({ 469 ... "llm": { 470 ... "model": "azure/gpt-4", 471 ... "api_key": "key", 472 ... "base_url": "https://...", 473 ... "api_version": "2024-02-15-preview" 474 ... }, 475 ... "embedder": { 476 ... "model": "azure/text-embedding-ada-002", 477 ... "api_key": "key", 478 ... "base_url": "https://...", 479 ... "api_version": "2024-02-15-preview" 480 ... } 481 ... }) 482 483 >>> # TOMLファイルから作成 484 >>> config = SemanticMemoryConfig.from_toml("config.toml", section="memory") 485 486 >>> # Qdrant Vector Store付きで作成 487 >>> config = SemanticMemoryConfig( 488 ... llm_config=llm_config, 489 ... embedder_config=embedder_config, 490 ... vector_store=QdrantVectorStoreConfig( 491 ... url="http://localhost:6333", 492 ... collection_name="my_memories", 493 ... ).to_dict(), 494 ... ) 495 """ 496 llm_config: LLMProviderConfig 497 embedder_config: LLMProviderConfig 498 vector_store: Optional[Dict[str, Any]] = None 499 rerank: Optional[Dict[str, Any]] = None 500 custom_fact_extraction_prompt: Optional[str] = None 501 custom_update_memory_prompt: Optional[str] = None 502 503 @classmethod 504 def from_dict(cls, data: Dict[str, Any]) -> "SemanticMemoryConfig": 505 """辞書からSemanticMemoryConfigを作成 506 507 Args: 508 data: 設定辞書。"llm"と"embedder"キーを含む必要がある 509 510 Returns: 511 SemanticMemoryConfig instance 512 513 Example: 514 >>> config = SemanticMemoryConfig.from_dict({ 515 ... "llm": {"model": "azure/gpt-4", "api_key": "key", "base_url": "..."}, 516 ... "embedder": {"model": "azure/text-embedding-ada-002", "api_key": "key", "base_url": "..."} 517 ... }) 518 """ 519 return cls( 520 llm_config=LLMProviderConfig.from_dict(data["llm"]), 521 embedder_config=LLMProviderConfig.from_dict(data["embedder"]), 522 vector_store=data.get("vector_store"), 523 rerank=data.get("rerank"), 524 custom_fact_extraction_prompt=data.get("custom_fact_extraction_prompt"), 525 custom_update_memory_prompt=data.get("custom_update_memory_prompt"), 526 ) 527 528 @classmethod 529 def from_toml(cls, toml_path: str, section: str = "memory") -> "SemanticMemoryConfig": 530 """TOMLファイルからSemanticMemoryConfigを作成 531 532 Args: 533 toml_path: TOMLファイルパス 534 section: セクション名(ドット区切りでネスト対応) 535 536 Returns: 537 SemanticMemoryConfig instance 538 539 Example: 540 >>> # config.toml の [memory] セクションを読み込み 541 >>> config = SemanticMemoryConfig.from_toml("config.toml") 542 543 >>> # 以下のTOML構造を期待: 544 >>> # [memory.llm] 545 >>> # model = "azure/gpt-4" 546 >>> # api_key = "..." 547 >>> # 548 >>> # [memory.embedder] 549 >>> # model = "azure/text-embedding-ada-002" 550 >>> # api_key = "..." 551 """ 552 if not TOML_AVAILABLE: 553 raise ImportError("toml library is required: pip install toml") 554 555 with open(toml_path, "r") as f: 556 toml_data = toml.load(f) 557 558 # ネストされたセクションをドット区切りで辿る 559 data = toml_data 560 for key in section.split("."): 561 data = data[key] 562 563 return cls.from_dict(data)
Semantic Memory設定
Example:
辞書から作成
config = SemanticMemoryConfig.from_dict({ ... "llm": { ... "model": "azure/gpt-4", ... "api_key": "key", ... "base_url": "https://...", ... "api_version": "2024-02-15-preview" ... }, ... "embedder": { ... "model": "azure/text-embedding-ada-002", ... "api_key": "key", ... "base_url": "https://...", ... "api_version": "2024-02-15-preview" ... } ... })
>>> # TOMLファイルから作成 >>> config = SemanticMemoryConfig.from_toml("config.toml", section="memory") >>> # Qdrant Vector Store付きで作成 >>> config = SemanticMemoryConfig( ... llm_config=llm_config, ... embedder_config=embedder_config, ... vector_store=QdrantVectorStoreConfig( ... url="http://localhost:6333", ... collection_name="my_memories", ... ).to_dict(), ... )
503 @classmethod 504 def from_dict(cls, data: Dict[str, Any]) -> "SemanticMemoryConfig": 505 """辞書からSemanticMemoryConfigを作成 506 507 Args: 508 data: 設定辞書。"llm"と"embedder"キーを含む必要がある 509 510 Returns: 511 SemanticMemoryConfig instance 512 513 Example: 514 >>> config = SemanticMemoryConfig.from_dict({ 515 ... "llm": {"model": "azure/gpt-4", "api_key": "key", "base_url": "..."}, 516 ... "embedder": {"model": "azure/text-embedding-ada-002", "api_key": "key", "base_url": "..."} 517 ... }) 518 """ 519 return cls( 520 llm_config=LLMProviderConfig.from_dict(data["llm"]), 521 embedder_config=LLMProviderConfig.from_dict(data["embedder"]), 522 vector_store=data.get("vector_store"), 523 rerank=data.get("rerank"), 524 custom_fact_extraction_prompt=data.get("custom_fact_extraction_prompt"), 525 custom_update_memory_prompt=data.get("custom_update_memory_prompt"), 526 )
辞書からSemanticMemoryConfigを作成
Args: data: 設定辞書。"llm"と"embedder"キーを含む必要がある
Returns: SemanticMemoryConfig instance
Example:
config = SemanticMemoryConfig.from_dict({ ... "llm": {"model": "azure/gpt-4", "api_key": "key", "base_url": "..."}, ... "embedder": {"model": "azure/text-embedding-ada-002", "api_key": "key", "base_url": "..."} ... })
528 @classmethod 529 def from_toml(cls, toml_path: str, section: str = "memory") -> "SemanticMemoryConfig": 530 """TOMLファイルからSemanticMemoryConfigを作成 531 532 Args: 533 toml_path: TOMLファイルパス 534 section: セクション名(ドット区切りでネスト対応) 535 536 Returns: 537 SemanticMemoryConfig instance 538 539 Example: 540 >>> # config.toml の [memory] セクションを読み込み 541 >>> config = SemanticMemoryConfig.from_toml("config.toml") 542 543 >>> # 以下のTOML構造を期待: 544 >>> # [memory.llm] 545 >>> # model = "azure/gpt-4" 546 >>> # api_key = "..." 547 >>> # 548 >>> # [memory.embedder] 549 >>> # model = "azure/text-embedding-ada-002" 550 >>> # api_key = "..." 551 """ 552 if not TOML_AVAILABLE: 553 raise ImportError("toml library is required: pip install toml") 554 555 with open(toml_path, "r") as f: 556 toml_data = toml.load(f) 557 558 # ネストされたセクションをドット区切りで辿る 559 data = toml_data 560 for key in section.split("."): 561 data = data[key] 562 563 return cls.from_dict(data)
TOMLファイルからSemanticMemoryConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: SemanticMemoryConfig instance
Example:
config.toml の [memory] セクションを読み込み
config = SemanticMemoryConfig.from_toml("config.toml")
>>> # 以下のTOML構造を期待: >>> # [memory.llm] >>> # model = "azure/gpt-4" >>> # api_key = "..." >>> # >>> # [memory.embedder] >>> # model = "azure/text-embedding-ada-002" >>> # api_key = "..."
244class EpisodicMemoryClient: 245 """ 246 汎用エピソディックメモリクライアント(Graphiti/FalkorDBベース) 247 248 Features: 249 - マルチプロバイダーLLM/Embedder/Reranker対応 250 - FalkorDBグラフストレージ 251 - エピソード追加・検索 252 253 Note: 254 - 組織階層(personal/role/org/global)は提供しない 255 - PIIマスキングは提供しない(利用側で実装) 256 - Community機能は提供しない(利用側で必要に応じて実装) 257 258 Warning: 259 - graphiti_coreの制約により、EMBEDDING_DIMはプロセス内でグローバルに設定される 260 - 異なるembedding_modelを使用する複数クライアントを同一プロセスで使用すると、 261 最後に作成されたクライアントのEMBEDDING_DIMが適用される 262 - 異なる次元数のモデルを使用する場合は、別プロセスで実行することを推奨 263 """ 264 265 # Embedding dimensions by model 266 EMBEDDING_DIMS = { 267 # OpenAI models 268 "text-embedding-3-small": 1536, 269 "text-embedding-3-large": 3072, 270 "text-embedding-ada-002": 1536, 271 # Azure variants 272 "text-embedding-3-small-1": 1536, 273 "ada": 1536, 274 # Voyage AI models 275 "voyage-3": 1024, 276 "voyage-3-lite": 512, 277 "voyage-code-3": 1024, 278 "voyage-finance-2": 1024, 279 "voyage-law-2": 1024, 280 "voyage-large-2": 1536, 281 "voyage-2": 1024, 282 # Gemini models 283 "text-embedding-004": 768, 284 "embedding-001": 768, 285 } 286 287 def __init__(self, config: EpisodicMemoryConfig): 288 """ 289 Initialize Episodic Memory Client 290 291 Args: 292 config: EpisodicMemoryConfig instance 293 294 Raises: 295 EpisodicMemoryConfigError: 設定が不正な場合 296 """ 297 if not config: 298 raise EpisodicMemoryConfigError("EpisodicMemoryConfig is required") 299 300 if not config.llm_config or not config.embedder_config: 301 raise EpisodicMemoryConfigError( 302 "Both llm_config and embedder_config are required" 303 ) 304 305 self._config = config 306 self._graphiti = None 307 self._driver = None 308 self._enabled = False 309 self._embedding_dim: Optional[int] = None 310 # Track clients for proper cleanup 311 self._llm_client = None 312 self._embedder = None 313 self._reranker = None 314 self._openai_client = None # For Azure/OpenAI async client 315 316 try: 317 from graphiti_core import Graphiti 318 from graphiti_core.driver.falkordb_driver import FalkorDriver 319 320 # Initialize FalkorDB driver 321 self._driver = FalkorDriver( 322 host=config.host, 323 port=config.port, 324 database=config.database, 325 ) 326 logger.info( 327 f"FalkorDB driver initialized - " 328 f"host={config.host}, port={config.port}, database={config.database}" 329 ) 330 331 # Initialize LLM client 332 self._llm_client, self._openai_client = self._create_llm_client(config.llm_config) 333 334 # Initialize embedder 335 self._embedder = self._create_embedder(config.embedder_config) 336 337 # Initialize reranker 338 self._reranker = self._create_reranker(config.llm_config, self._openai_client) 339 340 # Initialize Graphiti 341 self._graphiti = Graphiti( 342 graph_driver=self._driver, 343 llm_client=self._llm_client, 344 embedder=self._embedder, 345 cross_encoder=self._reranker, 346 ) 347 348 self._enabled = True 349 logger.info( 350 f"EpisodicMemoryClient initialized - " 351 f"llm={config.llm_config.model}, " 352 f"embedder={config.embedder_config.model}" 353 ) 354 355 except ImportError as e: 356 logger.warning( 357 f"graphiti-core not installed. " 358 f"Episodic memory will be disabled: {e}" 359 ) 360 self._enabled = False 361 except Exception as e: 362 logger.error(f"Failed to initialize EpisodicMemoryClient: {e}") 363 raise EpisodicMemoryConfigError( 364 f"Failed to initialize Graphiti: {e}" 365 ) from e 366 367 @property 368 def enabled(self) -> bool: 369 """クライアントが有効かどうか""" 370 return self._enabled 371 372 @property 373 def embedding_dim(self) -> Optional[int]: 374 """現在のembedding次元数""" 375 return self._embedding_dim 376 377 @property 378 def graphiti(self): 379 """内部Graphitiインスタンスへのアクセス(高度な操作用)""" 380 return self._graphiti 381 382 def _parse_model_string(self, model: str) -> tuple[str, str]: 383 """モデル文字列をプロバイダーとモデル名に分離""" 384 if "/" in model: 385 provider = model.split("/")[0] 386 model_name = model.split("/", 1)[1] 387 else: 388 provider = "openai" 389 model_name = model 390 return provider.lower(), model_name 391 392 def _get_api_key(self, config: LLMProviderConfig) -> str: 393 """APIキーを取得(SecretStr対応)""" 394 api_key = config.api_key 395 if hasattr(api_key, 'get_secret_value'): 396 return api_key.get_secret_value() 397 return api_key 398 399 def _get_embedding_dim(self, model_name: str) -> int: 400 """モデル名から埋め込み次元数を取得""" 401 # 完全一致 402 if model_name in self.EMBEDDING_DIMS: 403 return self.EMBEDDING_DIMS[model_name] 404 405 # 部分一致 406 model_lower = model_name.lower() 407 for key, dim in self.EMBEDDING_DIMS.items(): 408 if key in model_lower or model_lower in key: 409 return dim 410 411 # デフォルト 412 logger.warning( 413 f"Unknown embedding model '{model_name}', using default dimension 1536" 414 ) 415 return 1536 416 417 def _create_llm_client(self, llm_config: LLMProviderConfig): 418 """LLMクライアントを作成(マルチプロバイダー対応)""" 419 from graphiti_core.llm_client import LLMConfig, OpenAIClient 420 421 provider, model_name = self._parse_model_string(llm_config.model) 422 api_key = self._get_api_key(llm_config) 423 424 graphiti_config = LLMConfig( 425 model=model_name, 426 small_model=model_name, 427 ) 428 429 if provider == "azure": 430 from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient 431 from openai import AsyncAzureOpenAI 432 433 azure_client = AsyncAzureOpenAI( 434 azure_endpoint=llm_config.base_url, 435 api_key=api_key, 436 api_version=llm_config.api_version or "2025-03-01-preview" 437 ) 438 return AzureOpenAILLMClient( 439 azure_client=azure_client, 440 config=graphiti_config 441 ), azure_client 442 443 elif provider == "anthropic": 444 from graphiti_core.llm_client.anthropic_client import AnthropicClient 445 graphiti_config.api_key = api_key 446 client = AnthropicClient(config=graphiti_config) 447 return client, client.client 448 449 elif provider == "openai": 450 graphiti_config.api_key = api_key 451 client = OpenAIClient(config=graphiti_config) 452 return client, client.client 453 454 elif provider == "gemini": 455 from graphiti_core.llm_client.gemini_client import GeminiClient 456 graphiti_config.api_key = api_key 457 client = GeminiClient(config=graphiti_config) 458 return client, client.client 459 460 elif provider == "groq": 461 from graphiti_core.llm_client.groq_client import GroqClient 462 graphiti_config.api_key = api_key 463 client = GroqClient(config=graphiti_config) 464 return client, client.client 465 466 else: 467 raise EpisodicMemoryConfigError(f"Unsupported LLM provider: {provider}") 468 469 def _create_embedder(self, embedder_config: LLMProviderConfig): 470 """Embedderを作成(マルチプロバイダー対応) 471 472 Note: 473 graphiti_coreの制約により、EMBEDDING_DIMはプロセスレベルで設定される。 474 同一プロセスで異なるembedding次元を使用する複数クライアントは非推奨。 475 """ 476 from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig 477 478 provider, model_name = self._parse_model_string(embedder_config.model) 479 api_key = self._get_api_key(embedder_config) 480 481 # Get and store embedding dimension 482 self._embedding_dim = self._get_embedding_dim(model_name) 483 484 # Set embedding dimension environment variable (graphiti_core requirement) 485 # WARNING: This is a global setting that affects all clients in the process 486 os.environ['EMBEDDING_DIM'] = str(self._embedding_dim) 487 # Update graphiti_core constant 488 import graphiti_core.embedder.client as embedder_client 489 embedder_client.EMBEDDING_DIM = self._embedding_dim 490 logger.debug(f"Set EMBEDDING_DIM={self._embedding_dim} for model: {model_name}") 491 492 if provider == "azure": 493 from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient 494 from openai import AsyncAzureOpenAI 495 496 azure_client = AsyncAzureOpenAI( 497 azure_endpoint=embedder_config.base_url, 498 api_key=api_key, 499 api_version=embedder_config.api_version or "2023-05-15" 500 ) 501 return AzureOpenAIEmbedderClient( 502 azure_client=azure_client, 503 model=model_name 504 ) 505 506 elif provider == "openai": 507 config = OpenAIEmbedderConfig( 508 embedding_model=model_name, 509 api_key=api_key, 510 ) 511 return OpenAIEmbedder(config=config) 512 513 elif provider == "gemini": 514 from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig 515 config = GeminiEmbedderConfig( 516 embedding_model=model_name, 517 api_key=api_key, 518 ) 519 return GeminiEmbedder(config=config) 520 521 elif provider == "voyage": 522 from graphiti_core.embedder.voyage import VoyageEmbedder, VoyageEmbedderConfig 523 config = VoyageEmbedderConfig( 524 embedding_model=model_name, 525 api_key=api_key, 526 ) 527 return VoyageEmbedder(config=config) 528 529 else: 530 # Default to OpenAI format 531 config = OpenAIEmbedderConfig( 532 embedding_model=model_name, 533 api_key=api_key, 534 ) 535 return OpenAIEmbedder(config=config) 536 537 def _create_reranker(self, llm_config: LLMProviderConfig, openai_client): 538 """Rerankerを作成(マルチプロバイダー対応)""" 539 from graphiti_core.llm_client import LLMConfig 540 541 provider, model_name = self._parse_model_string(llm_config.model) 542 api_key = self._get_api_key(llm_config) 543 544 if provider in ("azure", "openai"): 545 from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient 546 reranker_config = LLMConfig(model=model_name) 547 return OpenAIRerankerClient(config=reranker_config, client=openai_client) 548 549 elif provider == "gemini": 550 from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient 551 reranker_config = LLMConfig( 552 api_key=api_key, 553 model="gemini-2.5-flash-lite-preview-06-17", 554 ) 555 return GeminiRerankerClient(config=reranker_config) 556 557 else: 558 # Anthropic, Groq等はネイティブrerankerがない 559 logger.warning( 560 f"Provider '{provider}' does not have native reranker support. " 561 "Reranking disabled." 562 ) 563 return None 564 565 # Maximum recommended episode body size (characters) 566 MAX_EPISODE_BODY_SIZE = 100000 567 568 async def add_episode( 569 self, 570 group_id: str, 571 content: str, 572 name: Optional[str] = None, 573 user_id: Optional[str] = None, 574 ) -> bool: 575 """ 576 エピソードを追加 577 578 テキストをエピソードとして追加します。 579 580 Args: 581 group_id: グループID(FalkorDB互換形式に自動正規化) 582 content: エピソード内容(テキスト) 583 name: エピソード名(省略時は自動生成) 584 user_id: ユーザーID(ソース記述に使用) 585 586 Returns: 587 成功した場合True 588 589 Raises: 590 EpisodicMemoryError: 追加に失敗した場合 591 592 Example: 593 >>> await client.add_episode( 594 ... group_id="user-123", 595 ... content="ユーザーはPythonに興味があります", 596 ... name="user_interest", 597 ... ) 598 599 Note: 600 episode_bodyが100,000文字を超える場合は警告を出力します。 601 大きなエピソードはパフォーマンスに影響する可能性があります。 602 """ 603 if not self._enabled: 604 return False 605 606 if not content: 607 raise EpisodicMemoryError("'content' is required") 608 609 try: 610 from graphiti_core.nodes import EpisodeType 611 612 # Normalize group_id 613 normalized_group_id = normalize_group_id(group_id) 614 615 # Warn if episode body is too large 616 body_size = len(content) 617 if body_size > self.MAX_EPISODE_BODY_SIZE: 618 logger.warning( 619 f"Episode body is large ({body_size} chars, recommended max: " 620 f"{self.MAX_EPISODE_BODY_SIZE}). Consider splitting into smaller episodes." 621 ) 622 623 # エピソード名を決定 624 actual_name = name 625 if not actual_name: 626 actual_name = content[:50] + "..." if len(content) > 50 else content 627 628 # Add episode to Graphiti 629 # NOTE: update_communities=False to avoid graphiti-core bugs 630 await self._graphiti.add_episode( 631 group_id=normalized_group_id, 632 name=actual_name, 633 episode_body=content, 634 source=EpisodeType.message, 635 source_description=f"sdk-client-{user_id}" if user_id else "sdk-client", 636 reference_time=datetime.now(timezone.utc), 637 update_communities=False, 638 ) 639 640 logger.debug( 641 f"Episode added - group_id={normalized_group_id}" 642 ) 643 return True 644 645 except Exception as e: 646 logger.error(f"Failed to add episode: {e}") 647 raise EpisodicMemoryError(f"Failed to add episode: {e}") from e 648 649 async def search( 650 self, 651 query: str, 652 group_ids: Optional[List[str]] = None, 653 max_facts: int = 10, 654 ) -> Dict[str, Any]: 655 """ 656 エピソード記憶を検索 657 658 Args: 659 query: 検索クエリ 660 group_ids: 検索対象グループIDリスト(省略時は全グループ) 661 max_facts: 結果数上限(デフォルト: 10) 662 663 Returns: 664 Dict[str, Any]: 検索結果 665 - success: bool 666 - data: { 667 "query": str, 668 "results": [{"fact": str, "created_at": str, "uuid": str}], 669 "total_found": int 670 } 671 - error: str, error_code: str (失敗時) 672 673 Example: 674 >>> result = await client.search( 675 ... query="Pythonに関する情報", 676 ... group_ids=["user-123"], 677 ... max_facts=5, 678 ... ) 679 >>> if result["success"]: 680 ... for item in result["data"]["results"]: 681 ... print(item["fact"]) 682 683 Raises: 684 EpisodicMemoryError: 検索に失敗した場合 685 """ 686 if not self._enabled: 687 return { 688 "success": False, 689 "error": "Episodic memory is not enabled", 690 "error_code": "DISABLED", 691 } 692 693 try: 694 # Normalize group_ids 695 normalized_group_ids = None 696 if group_ids: 697 normalized_group_ids = [normalize_group_id(gid) for gid in group_ids] 698 699 # Search 700 edges = await self._graphiti.search( 701 query=query, 702 group_ids=normalized_group_ids, 703 num_results=max_facts, 704 ) 705 706 # Convert to fact dicts 707 facts = [] 708 for edge in edges: 709 fact = { 710 "fact": edge.fact if hasattr(edge, 'fact') else str(edge), 711 "created_at": edge.created_at.isoformat() if hasattr(edge, 'created_at') else None, 712 "uuid": edge.uuid if hasattr(edge, 'uuid') else None, 713 } 714 facts.append(fact) 715 716 logger.debug(f"Search returned {len(facts)} facts") 717 return { 718 "success": True, 719 "data": { 720 "query": query, 721 "results": facts, 722 "total_found": len(facts), 723 }, 724 } 725 726 except Exception as e: 727 logger.error(f"Failed to search: {e}") 728 return { 729 "success": False, 730 "error": f"Search error: {str(e)}", 731 "error_code": "EPISODIC_SEARCH_ERROR", 732 } 733 734 async def get_recent_episodes( 735 self, 736 group_id: str, 737 last_n: int = 5, 738 ) -> List[Any]: 739 """ 740 最近のエピソードを取得 741 742 Args: 743 group_id: グループID 744 last_n: 取得するエピソード数 745 746 Returns: 747 エピソードリスト 748 """ 749 if not self._enabled: 750 return [] 751 752 try: 753 normalized_group_id = normalize_group_id(group_id) 754 755 episodes = await self._graphiti.retrieve_episodes( 756 group_ids=[normalized_group_id], 757 last_n=last_n, 758 reference_time=datetime.now(timezone.utc), 759 ) 760 761 logger.debug( 762 f"Retrieved {len(episodes)} recent episodes for group_id={group_id}" 763 ) 764 return episodes 765 766 except Exception as e: 767 logger.error(f"Failed to get recent episodes: {e}") 768 raise EpisodicMemoryError( 769 f"Failed to get recent episodes: {e}" 770 ) from e 771 772 def format_facts_for_context( 773 self, 774 facts: List[Dict[str, Any]], 775 max_facts: int = 5, 776 title: str = "### Episodic Memory", 777 ) -> str: 778 """ 779 ファクトをLLMコンテキスト用にフォーマット 780 781 Args: 782 facts: ファクトリスト 783 max_facts: 最大ファクト数 784 title: セクションタイトル 785 786 Returns: 787 フォーマットされた文字列 788 """ 789 if not facts: 790 return "" 791 792 lines = [title] 793 794 for i, fact in enumerate(facts[:max_facts], 1): 795 fact_text = fact.get("fact", "") 796 created_at = fact.get("created_at") 797 798 if created_at: 799 try: 800 dt = datetime.fromisoformat(created_at.replace("Z", "+00:00")) 801 time_str = dt.strftime("%Y-%m-%d %H:%M") 802 except: 803 time_str = created_at 804 lines.append(f"{i}. [{time_str}] {fact_text}") 805 else: 806 lines.append(f"{i}. {fact_text}") 807 808 return "\n".join(lines) 809 810 async def close(self) -> None: 811 """コネクションとクライアントリソースをクローズ 812 813 Note: 814 FalkorDBドライバー、LLMクライアント、Embedder、Rerankerを 815 すべてクローズします。 816 """ 817 errors = [] 818 819 # Close FalkorDB driver 820 try: 821 if self._driver and hasattr(self._driver, 'close'): 822 await self._driver.close() 823 self._driver = None 824 except Exception as e: 825 errors.append(f"driver: {e}") 826 827 # Close OpenAI/Azure async client 828 try: 829 if self._openai_client and hasattr(self._openai_client, 'close'): 830 await self._openai_client.close() 831 self._openai_client = None 832 except Exception as e: 833 errors.append(f"openai_client: {e}") 834 835 # Note: LLM client, embedder, reranker typically don't have close methods 836 # but we clear references to allow garbage collection 837 self._llm_client = None 838 self._embedder = None 839 self._reranker = None 840 self._graphiti = None 841 self._enabled = False 842 843 if errors: 844 logger.warning(f"Errors during close: {', '.join(errors)}") 845 else: 846 logger.debug("EpisodicMemoryClient closed successfully")
汎用エピソディックメモリクライアント(Graphiti/FalkorDBベース)
Features:
- マルチプロバイダーLLM/Embedder/Reranker対応
- FalkorDBグラフストレージ
- エピソード追加・検索
Note:
- 組織階層(personal/role/org/global)は提供しない
- PIIマスキングは提供しない(利用側で実装)
- Community機能は提供しない(利用側で必要に応じて実装)
Warning:
- graphiti_coreの制約により、EMBEDDING_DIMはプロセス内でグローバルに設定される
- 異なるembedding_modelを使用する複数クライアントを同一プロセスで使用すると、 最後に作成されたクライアントのEMBEDDING_DIMが適用される
- 異なる次元数のモデルを使用する場合は、別プロセスで実行することを推奨
287 def __init__(self, config: EpisodicMemoryConfig): 288 """ 289 Initialize Episodic Memory Client 290 291 Args: 292 config: EpisodicMemoryConfig instance 293 294 Raises: 295 EpisodicMemoryConfigError: 設定が不正な場合 296 """ 297 if not config: 298 raise EpisodicMemoryConfigError("EpisodicMemoryConfig is required") 299 300 if not config.llm_config or not config.embedder_config: 301 raise EpisodicMemoryConfigError( 302 "Both llm_config and embedder_config are required" 303 ) 304 305 self._config = config 306 self._graphiti = None 307 self._driver = None 308 self._enabled = False 309 self._embedding_dim: Optional[int] = None 310 # Track clients for proper cleanup 311 self._llm_client = None 312 self._embedder = None 313 self._reranker = None 314 self._openai_client = None # For Azure/OpenAI async client 315 316 try: 317 from graphiti_core import Graphiti 318 from graphiti_core.driver.falkordb_driver import FalkorDriver 319 320 # Initialize FalkorDB driver 321 self._driver = FalkorDriver( 322 host=config.host, 323 port=config.port, 324 database=config.database, 325 ) 326 logger.info( 327 f"FalkorDB driver initialized - " 328 f"host={config.host}, port={config.port}, database={config.database}" 329 ) 330 331 # Initialize LLM client 332 self._llm_client, self._openai_client = self._create_llm_client(config.llm_config) 333 334 # Initialize embedder 335 self._embedder = self._create_embedder(config.embedder_config) 336 337 # Initialize reranker 338 self._reranker = self._create_reranker(config.llm_config, self._openai_client) 339 340 # Initialize Graphiti 341 self._graphiti = Graphiti( 342 graph_driver=self._driver, 343 llm_client=self._llm_client, 344 embedder=self._embedder, 345 cross_encoder=self._reranker, 346 ) 347 348 self._enabled = True 349 logger.info( 350 f"EpisodicMemoryClient initialized - " 351 f"llm={config.llm_config.model}, " 352 f"embedder={config.embedder_config.model}" 353 ) 354 355 except ImportError as e: 356 logger.warning( 357 f"graphiti-core not installed. " 358 f"Episodic memory will be disabled: {e}" 359 ) 360 self._enabled = False 361 except Exception as e: 362 logger.error(f"Failed to initialize EpisodicMemoryClient: {e}") 363 raise EpisodicMemoryConfigError( 364 f"Failed to initialize Graphiti: {e}" 365 ) from e
Initialize Episodic Memory Client
Args: config: EpisodicMemoryConfig instance
Raises: EpisodicMemoryConfigError: 設定が不正な場合
372 @property 373 def embedding_dim(self) -> Optional[int]: 374 """現在のembedding次元数""" 375 return self._embedding_dim
現在のembedding次元数
377 @property 378 def graphiti(self): 379 """内部Graphitiインスタンスへのアクセス(高度な操作用)""" 380 return self._graphiti
内部Graphitiインスタンスへのアクセス(高度な操作用)
568 async def add_episode( 569 self, 570 group_id: str, 571 content: str, 572 name: Optional[str] = None, 573 user_id: Optional[str] = None, 574 ) -> bool: 575 """ 576 エピソードを追加 577 578 テキストをエピソードとして追加します。 579 580 Args: 581 group_id: グループID(FalkorDB互換形式に自動正規化) 582 content: エピソード内容(テキスト) 583 name: エピソード名(省略時は自動生成) 584 user_id: ユーザーID(ソース記述に使用) 585 586 Returns: 587 成功した場合True 588 589 Raises: 590 EpisodicMemoryError: 追加に失敗した場合 591 592 Example: 593 >>> await client.add_episode( 594 ... group_id="user-123", 595 ... content="ユーザーはPythonに興味があります", 596 ... name="user_interest", 597 ... ) 598 599 Note: 600 episode_bodyが100,000文字を超える場合は警告を出力します。 601 大きなエピソードはパフォーマンスに影響する可能性があります。 602 """ 603 if not self._enabled: 604 return False 605 606 if not content: 607 raise EpisodicMemoryError("'content' is required") 608 609 try: 610 from graphiti_core.nodes import EpisodeType 611 612 # Normalize group_id 613 normalized_group_id = normalize_group_id(group_id) 614 615 # Warn if episode body is too large 616 body_size = len(content) 617 if body_size > self.MAX_EPISODE_BODY_SIZE: 618 logger.warning( 619 f"Episode body is large ({body_size} chars, recommended max: " 620 f"{self.MAX_EPISODE_BODY_SIZE}). Consider splitting into smaller episodes." 621 ) 622 623 # エピソード名を決定 624 actual_name = name 625 if not actual_name: 626 actual_name = content[:50] + "..." if len(content) > 50 else content 627 628 # Add episode to Graphiti 629 # NOTE: update_communities=False to avoid graphiti-core bugs 630 await self._graphiti.add_episode( 631 group_id=normalized_group_id, 632 name=actual_name, 633 episode_body=content, 634 source=EpisodeType.message, 635 source_description=f"sdk-client-{user_id}" if user_id else "sdk-client", 636 reference_time=datetime.now(timezone.utc), 637 update_communities=False, 638 ) 639 640 logger.debug( 641 f"Episode added - group_id={normalized_group_id}" 642 ) 643 return True 644 645 except Exception as e: 646 logger.error(f"Failed to add episode: {e}") 647 raise EpisodicMemoryError(f"Failed to add episode: {e}") from e
エピソードを追加
テキストをエピソードとして追加します。
Args: group_id: グループID(FalkorDB互換形式に自動正規化) content: エピソード内容(テキスト) name: エピソード名(省略時は自動生成) user_id: ユーザーID(ソース記述に使用)
Returns: 成功した場合True
Raises: EpisodicMemoryError: 追加に失敗した場合
Example:
await client.add_episode( ... group_id="user-123", ... content="ユーザーはPythonに興味があります", ... name="user_interest", ... )
Note: episode_bodyが100,000文字を超える場合は警告を出力します。 大きなエピソードはパフォーマンスに影響する可能性があります。
649 async def search( 650 self, 651 query: str, 652 group_ids: Optional[List[str]] = None, 653 max_facts: int = 10, 654 ) -> Dict[str, Any]: 655 """ 656 エピソード記憶を検索 657 658 Args: 659 query: 検索クエリ 660 group_ids: 検索対象グループIDリスト(省略時は全グループ) 661 max_facts: 結果数上限(デフォルト: 10) 662 663 Returns: 664 Dict[str, Any]: 検索結果 665 - success: bool 666 - data: { 667 "query": str, 668 "results": [{"fact": str, "created_at": str, "uuid": str}], 669 "total_found": int 670 } 671 - error: str, error_code: str (失敗時) 672 673 Example: 674 >>> result = await client.search( 675 ... query="Pythonに関する情報", 676 ... group_ids=["user-123"], 677 ... max_facts=5, 678 ... ) 679 >>> if result["success"]: 680 ... for item in result["data"]["results"]: 681 ... print(item["fact"]) 682 683 Raises: 684 EpisodicMemoryError: 検索に失敗した場合 685 """ 686 if not self._enabled: 687 return { 688 "success": False, 689 "error": "Episodic memory is not enabled", 690 "error_code": "DISABLED", 691 } 692 693 try: 694 # Normalize group_ids 695 normalized_group_ids = None 696 if group_ids: 697 normalized_group_ids = [normalize_group_id(gid) for gid in group_ids] 698 699 # Search 700 edges = await self._graphiti.search( 701 query=query, 702 group_ids=normalized_group_ids, 703 num_results=max_facts, 704 ) 705 706 # Convert to fact dicts 707 facts = [] 708 for edge in edges: 709 fact = { 710 "fact": edge.fact if hasattr(edge, 'fact') else str(edge), 711 "created_at": edge.created_at.isoformat() if hasattr(edge, 'created_at') else None, 712 "uuid": edge.uuid if hasattr(edge, 'uuid') else None, 713 } 714 facts.append(fact) 715 716 logger.debug(f"Search returned {len(facts)} facts") 717 return { 718 "success": True, 719 "data": { 720 "query": query, 721 "results": facts, 722 "total_found": len(facts), 723 }, 724 } 725 726 except Exception as e: 727 logger.error(f"Failed to search: {e}") 728 return { 729 "success": False, 730 "error": f"Search error: {str(e)}", 731 "error_code": "EPISODIC_SEARCH_ERROR", 732 }
エピソード記憶を検索
Args: query: 検索クエリ group_ids: 検索対象グループIDリスト(省略時は全グループ) max_facts: 結果数上限(デフォルト: 10)
Returns: Dict[str, Any]: 検索結果 - success: bool - data: { "query": str, "results": [{"fact": str, "created_at": str, "uuid": str}], "total_found": int } - error: str, error_code: str (失敗時)
Example:
result = await client.search( ... query="Pythonに関する情報", ... group_ids=["user-123"], ... max_facts=5, ... ) if result["success"]: ... for item in result["data"]["results"]: ... print(item["fact"])
Raises: EpisodicMemoryError: 検索に失敗した場合
734 async def get_recent_episodes( 735 self, 736 group_id: str, 737 last_n: int = 5, 738 ) -> List[Any]: 739 """ 740 最近のエピソードを取得 741 742 Args: 743 group_id: グループID 744 last_n: 取得するエピソード数 745 746 Returns: 747 エピソードリスト 748 """ 749 if not self._enabled: 750 return [] 751 752 try: 753 normalized_group_id = normalize_group_id(group_id) 754 755 episodes = await self._graphiti.retrieve_episodes( 756 group_ids=[normalized_group_id], 757 last_n=last_n, 758 reference_time=datetime.now(timezone.utc), 759 ) 760 761 logger.debug( 762 f"Retrieved {len(episodes)} recent episodes for group_id={group_id}" 763 ) 764 return episodes 765 766 except Exception as e: 767 logger.error(f"Failed to get recent episodes: {e}") 768 raise EpisodicMemoryError( 769 f"Failed to get recent episodes: {e}" 770 ) from e
最近のエピソードを取得
Args: group_id: グループID last_n: 取得するエピソード数
Returns: エピソードリスト
772 def format_facts_for_context( 773 self, 774 facts: List[Dict[str, Any]], 775 max_facts: int = 5, 776 title: str = "### Episodic Memory", 777 ) -> str: 778 """ 779 ファクトをLLMコンテキスト用にフォーマット 780 781 Args: 782 facts: ファクトリスト 783 max_facts: 最大ファクト数 784 title: セクションタイトル 785 786 Returns: 787 フォーマットされた文字列 788 """ 789 if not facts: 790 return "" 791 792 lines = [title] 793 794 for i, fact in enumerate(facts[:max_facts], 1): 795 fact_text = fact.get("fact", "") 796 created_at = fact.get("created_at") 797 798 if created_at: 799 try: 800 dt = datetime.fromisoformat(created_at.replace("Z", "+00:00")) 801 time_str = dt.strftime("%Y-%m-%d %H:%M") 802 except: 803 time_str = created_at 804 lines.append(f"{i}. [{time_str}] {fact_text}") 805 else: 806 lines.append(f"{i}. {fact_text}") 807 808 return "\n".join(lines)
ファクトをLLMコンテキスト用にフォーマット
Args: facts: ファクトリスト max_facts: 最大ファクト数 title: セクションタイトル
Returns: フォーマットされた文字列
810 async def close(self) -> None: 811 """コネクションとクライアントリソースをクローズ 812 813 Note: 814 FalkorDBドライバー、LLMクライアント、Embedder、Rerankerを 815 すべてクローズします。 816 """ 817 errors = [] 818 819 # Close FalkorDB driver 820 try: 821 if self._driver and hasattr(self._driver, 'close'): 822 await self._driver.close() 823 self._driver = None 824 except Exception as e: 825 errors.append(f"driver: {e}") 826 827 # Close OpenAI/Azure async client 828 try: 829 if self._openai_client and hasattr(self._openai_client, 'close'): 830 await self._openai_client.close() 831 self._openai_client = None 832 except Exception as e: 833 errors.append(f"openai_client: {e}") 834 835 # Note: LLM client, embedder, reranker typically don't have close methods 836 # but we clear references to allow garbage collection 837 self._llm_client = None 838 self._embedder = None 839 self._reranker = None 840 self._graphiti = None 841 self._enabled = False 842 843 if errors: 844 logger.warning(f"Errors during close: {', '.join(errors)}") 845 else: 846 logger.debug("EpisodicMemoryClient closed successfully")
コネクションとクライアントリソースをクローズ
Note: FalkorDBドライバー、LLMクライアント、Embedder、Rerankerを すべてクローズします。
153@dataclass 154class EpisodicMemoryConfig: 155 """Episodic Memory設定 156 157 FalkorDB/Graphitiベースのエピソード記憶システムの設定。 158 159 Attributes: 160 llm_config: LLMプロバイダー設定 161 embedder_config: Embedderプロバイダー設定 162 host: FalkorDBホスト(デフォルト: localhost) 163 port: FalkorDBポート(デフォルト: 6379) 164 database: データベース名(デフォルト: graphiti) 165 166 Example: 167 >>> from agenticstar_platform.memory import ( 168 ... EpisodicMemoryConfig, EpisodicMemoryClient, LLMProviderConfig 169 ... ) 170 >>> 171 >>> config = EpisodicMemoryConfig( 172 ... llm_config=LLMProviderConfig( 173 ... model="azure/gpt-4", 174 ... api_key="your-key", 175 ... base_url="https://your-resource.openai.azure.com", 176 ... ), 177 ... embedder_config=LLMProviderConfig( 178 ... model="azure/text-embedding-3-small", 179 ... api_key="your-key", 180 ... base_url="https://your-resource.openai.azure.com", 181 ... ), 182 ... host="localhost", 183 ... port=6379, 184 ... ) 185 >>> client = EpisodicMemoryClient(config) 186 """ 187 llm_config: LLMProviderConfig 188 embedder_config: LLMProviderConfig 189 # FalkorDB connection 190 host: str = "localhost" 191 port: int = 6379 192 database: str = "graphiti" 193 194 @classmethod 195 def from_dict(cls, data: Dict[str, Any]) -> "EpisodicMemoryConfig": 196 """辞書からEpisodicMemoryConfigを作成 197 198 Args: 199 data: 設定辞書 200 201 Returns: 202 EpisodicMemoryConfig instance 203 204 Example: 205 >>> config = EpisodicMemoryConfig.from_dict({ 206 ... "llm": {"model": "azure/gpt-4", "api_key": "...", "base_url": "..."}, 207 ... "embedder": {"model": "azure/text-embedding-3-small", ...}, 208 ... "host": "localhost", 209 ... "port": 6379, 210 ... }) 211 """ 212 # LLM config 213 llm_data = data.get("llm") or data.get("llm_config") or {} 214 llm_config = LLMProviderConfig( 215 model=llm_data.get("model", ""), 216 api_key=llm_data.get("api_key", ""), 217 base_url=llm_data.get("base_url"), 218 api_version=llm_data.get("api_version"), 219 ) 220 221 # Embedder config 222 embedder_data = data.get("embedder") or data.get("embedder_config") or {} 223 embedder_config = LLMProviderConfig( 224 model=embedder_data.get("model", ""), 225 api_key=embedder_data.get("api_key", ""), 226 base_url=embedder_data.get("base_url"), 227 api_version=embedder_data.get("api_version"), 228 ) 229 230 # FalkorDB connection 231 host = data.get("host", "localhost") 232 port = data.get("port", 6379) 233 database = data.get("database", "graphiti") 234 235 return cls( 236 llm_config=llm_config, 237 embedder_config=embedder_config, 238 host=host, 239 port=port, 240 database=database, 241 )
Episodic Memory設定
FalkorDB/Graphitiベースのエピソード記憶システムの設定。
Attributes: llm_config: LLMプロバイダー設定 embedder_config: Embedderプロバイダー設定 host: FalkorDBホスト(デフォルト: localhost) port: FalkorDBポート(デフォルト: 6379) database: データベース名(デフォルト: graphiti)
Example:
from agenticstar_platform.memory import ( ... EpisodicMemoryConfig, EpisodicMemoryClient, LLMProviderConfig ... )
config = EpisodicMemoryConfig( ... llm_config=LLMProviderConfig( ... model="azure/gpt-4", ... api_key="your-key", ... base_url="https://your-resource.openai.azure.com", ... ), ... embedder_config=LLMProviderConfig( ... model="azure/text-embedding-3-small", ... api_key="your-key", ... base_url="https://your-resource.openai.azure.com", ... ), ... host="localhost", ... port=6379, ... ) client = EpisodicMemoryClient(config)
194 @classmethod 195 def from_dict(cls, data: Dict[str, Any]) -> "EpisodicMemoryConfig": 196 """辞書からEpisodicMemoryConfigを作成 197 198 Args: 199 data: 設定辞書 200 201 Returns: 202 EpisodicMemoryConfig instance 203 204 Example: 205 >>> config = EpisodicMemoryConfig.from_dict({ 206 ... "llm": {"model": "azure/gpt-4", "api_key": "...", "base_url": "..."}, 207 ... "embedder": {"model": "azure/text-embedding-3-small", ...}, 208 ... "host": "localhost", 209 ... "port": 6379, 210 ... }) 211 """ 212 # LLM config 213 llm_data = data.get("llm") or data.get("llm_config") or {} 214 llm_config = LLMProviderConfig( 215 model=llm_data.get("model", ""), 216 api_key=llm_data.get("api_key", ""), 217 base_url=llm_data.get("base_url"), 218 api_version=llm_data.get("api_version"), 219 ) 220 221 # Embedder config 222 embedder_data = data.get("embedder") or data.get("embedder_config") or {} 223 embedder_config = LLMProviderConfig( 224 model=embedder_data.get("model", ""), 225 api_key=embedder_data.get("api_key", ""), 226 base_url=embedder_data.get("base_url"), 227 api_version=embedder_data.get("api_version"), 228 ) 229 230 # FalkorDB connection 231 host = data.get("host", "localhost") 232 port = data.get("port", 6379) 233 database = data.get("database", "graphiti") 234 235 return cls( 236 llm_config=llm_config, 237 embedder_config=embedder_config, 238 host=host, 239 port=port, 240 database=database, 241 )
辞書からEpisodicMemoryConfigを作成
Args: data: 設定辞書
Returns: EpisodicMemoryConfig instance
Example:
config = EpisodicMemoryConfig.from_dict({ ... "llm": {"model": "azure/gpt-4", "api_key": "...", "base_url": "..."}, ... "embedder": {"model": "azure/text-embedding-3-small", ...}, ... "host": "localhost", ... "port": 6379, ... })
104@dataclass 105class LLMProviderConfig: 106 """LLMプロバイダー設定 107 108 マルチプロバイダー対応: 109 - azure: Azure OpenAI 110 - openai: OpenAI 111 - anthropic: Anthropic Claude 112 - bedrock: AWS Bedrock 113 - gemini: Google Gemini 114 - groq: Groq 115 116 Example: 117 >>> # 辞書から作成 118 >>> config = LLMProviderConfig.from_dict({ 119 ... "model": "azure/gpt-4", 120 ... "api_key": "your-api-key", 121 ... "base_url": "https://your-resource.openai.azure.com/", 122 ... "api_version": "2024-02-15-preview" 123 ... }) 124 125 >>> # TOMLファイルから作成 126 >>> config = LLMProviderConfig.from_toml("config.toml", section="memory.llm") 127 """ 128 model: str # "azure/gpt-4" or "openai/gpt-4" format 129 api_key: str 130 base_url: Optional[str] = None # For Azure 131 api_version: Optional[str] = None # For Azure 132 # AWS Bedrock specific 133 aws_access_key_id: Optional[str] = None 134 aws_secret_access_key: Optional[str] = None 135 aws_region_name: Optional[str] = None 136 137 @classmethod 138 def from_dict(cls, data: Dict[str, Any]) -> "LLMProviderConfig": 139 """辞書からLLMProviderConfigを作成 140 141 Args: 142 data: 設定辞書 143 144 Returns: 145 LLMProviderConfig instance 146 147 Example: 148 >>> config = LLMProviderConfig.from_dict({ 149 ... "model": "azure/gpt-4", 150 ... "api_key": "your-api-key", 151 ... "base_url": "https://your-resource.openai.azure.com/", 152 ... "api_version": "2024-02-15-preview" 153 ... }) 154 """ 155 return cls( 156 model=data["model"], 157 api_key=data["api_key"], 158 base_url=data.get("base_url"), 159 api_version=data.get("api_version"), 160 aws_access_key_id=data.get("aws_access_key_id"), 161 aws_secret_access_key=data.get("aws_secret_access_key"), 162 aws_region_name=data.get("aws_region_name"), 163 ) 164 165 @classmethod 166 def from_toml(cls, toml_path: str, section: str) -> "LLMProviderConfig": 167 """TOMLファイルからLLMProviderConfigを作成 168 169 Args: 170 toml_path: TOMLファイルパス 171 section: セクション名(ドット区切りでネスト対応) 172 173 Returns: 174 LLMProviderConfig instance 175 176 Example: 177 >>> # config.toml の [memory.llm] セクションを読み込み 178 >>> config = LLMProviderConfig.from_toml("config.toml", section="memory.llm") 179 180 >>> # embedder設定の読み込み 181 >>> embedder = LLMProviderConfig.from_toml("config.toml", section="memory.embedder") 182 """ 183 if not TOML_AVAILABLE: 184 raise ImportError("toml library is required: pip install toml") 185 186 with open(toml_path, "r") as f: 187 toml_data = toml.load(f) 188 189 # ネストされたセクションをドット区切りで辿る 190 data = toml_data 191 for key in section.split("."): 192 data = data[key] 193 194 return cls.from_dict(data) 195 196 @classmethod 197 def from_env(cls, prefix: str = "LLM_") -> "LLMProviderConfig": 198 """環境変数からLLMProviderConfigを作成 199 200 環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。 201 202 Args: 203 prefix: 環境変数名のプレフィックス(デフォルト: "LLM_") 204 205 Returns: 206 LLMProviderConfig instance 207 208 Raises: 209 ValueError: 必須の環境変数(MODEL, API_KEY)が設定されていない場合 210 211 Example: 212 >>> # 環境変数から作成(LLM_MODEL, LLM_API_KEY, etc.) 213 >>> config = LLMProviderConfig.from_env() 214 215 >>> # Memory用LLM設定(MEMORY_LLM_MODEL, MEMORY_LLM_API_KEY, etc.) 216 >>> config = LLMProviderConfig.from_env(prefix="MEMORY_LLM_") 217 218 >>> # Embedder設定(EMBEDDER_MODEL, EMBEDDER_API_KEY, etc.) 219 >>> embedder = LLMProviderConfig.from_env(prefix="EMBEDDER_") 220 221 Environment Variables: 222 {prefix}MODEL: モデル名(必須)例: "azure/gpt-4", "openai/gpt-4" 223 {prefix}API_KEY: APIキー(必須) 224 {prefix}BASE_URL: ベースURL(Azure等で必要) 225 {prefix}API_VERSION: APIバージョン(Azure等で必要) 226 {prefix}AWS_ACCESS_KEY_ID: AWS Bedrock用 227 {prefix}AWS_SECRET_ACCESS_KEY: AWS Bedrock用 228 {prefix}AWS_REGION_NAME: AWS Bedrock用 229 """ 230 def get_env(key: str) -> Optional[str]: 231 return os.environ.get(f"{prefix}{key}") or None 232 233 model = get_env("MODEL") 234 api_key = get_env("API_KEY") 235 236 if not model: 237 raise ValueError(f"環境変数 {prefix}MODEL が設定されていません") 238 if not api_key: 239 raise ValueError(f"環境変数 {prefix}API_KEY が設定されていません") 240 241 return cls( 242 model=model, 243 api_key=api_key, 244 base_url=get_env("BASE_URL"), 245 api_version=get_env("API_VERSION"), 246 aws_access_key_id=get_env("AWS_ACCESS_KEY_ID"), 247 aws_secret_access_key=get_env("AWS_SECRET_ACCESS_KEY"), 248 aws_region_name=get_env("AWS_REGION_NAME"), 249 )
LLMプロバイダー設定
マルチプロバイダー対応:
- azure: Azure OpenAI
- openai: OpenAI
- anthropic: Anthropic Claude
- bedrock: AWS Bedrock
- gemini: Google Gemini
- groq: Groq
Example:
辞書から作成
config = LLMProviderConfig.from_dict({ ... "model": "azure/gpt-4", ... "api_key": "your-api-key", ... "base_url": "https://your-resource.openai.azure.com/", ... "api_version": "2024-02-15-preview" ... })
>>> # TOMLファイルから作成 >>> config = LLMProviderConfig.from_toml("config.toml", section="memory.llm")
137 @classmethod 138 def from_dict(cls, data: Dict[str, Any]) -> "LLMProviderConfig": 139 """辞書からLLMProviderConfigを作成 140 141 Args: 142 data: 設定辞書 143 144 Returns: 145 LLMProviderConfig instance 146 147 Example: 148 >>> config = LLMProviderConfig.from_dict({ 149 ... "model": "azure/gpt-4", 150 ... "api_key": "your-api-key", 151 ... "base_url": "https://your-resource.openai.azure.com/", 152 ... "api_version": "2024-02-15-preview" 153 ... }) 154 """ 155 return cls( 156 model=data["model"], 157 api_key=data["api_key"], 158 base_url=data.get("base_url"), 159 api_version=data.get("api_version"), 160 aws_access_key_id=data.get("aws_access_key_id"), 161 aws_secret_access_key=data.get("aws_secret_access_key"), 162 aws_region_name=data.get("aws_region_name"), 163 )
辞書からLLMProviderConfigを作成
Args: data: 設定辞書
Returns: LLMProviderConfig instance
Example:
config = LLMProviderConfig.from_dict({ ... "model": "azure/gpt-4", ... "api_key": "your-api-key", ... "base_url": "https://your-resource.openai.azure.com/", ... "api_version": "2024-02-15-preview" ... })
165 @classmethod 166 def from_toml(cls, toml_path: str, section: str) -> "LLMProviderConfig": 167 """TOMLファイルからLLMProviderConfigを作成 168 169 Args: 170 toml_path: TOMLファイルパス 171 section: セクション名(ドット区切りでネスト対応) 172 173 Returns: 174 LLMProviderConfig instance 175 176 Example: 177 >>> # config.toml の [memory.llm] セクションを読み込み 178 >>> config = LLMProviderConfig.from_toml("config.toml", section="memory.llm") 179 180 >>> # embedder設定の読み込み 181 >>> embedder = LLMProviderConfig.from_toml("config.toml", section="memory.embedder") 182 """ 183 if not TOML_AVAILABLE: 184 raise ImportError("toml library is required: pip install toml") 185 186 with open(toml_path, "r") as f: 187 toml_data = toml.load(f) 188 189 # ネストされたセクションをドット区切りで辿る 190 data = toml_data 191 for key in section.split("."): 192 data = data[key] 193 194 return cls.from_dict(data)
TOMLファイルからLLMProviderConfigを作成
Args: toml_path: TOMLファイルパス section: セクション名(ドット区切りでネスト対応)
Returns: LLMProviderConfig instance
Example:
config.toml の [memory.llm] セクションを読み込み
config = LLMProviderConfig.from_toml("config.toml", section="memory.llm")
>>> # embedder設定の読み込み >>> embedder = LLMProviderConfig.from_toml("config.toml", section="memory.embedder")
196 @classmethod 197 def from_env(cls, prefix: str = "LLM_") -> "LLMProviderConfig": 198 """環境変数からLLMProviderConfigを作成 199 200 環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。 201 202 Args: 203 prefix: 環境変数名のプレフィックス(デフォルト: "LLM_") 204 205 Returns: 206 LLMProviderConfig instance 207 208 Raises: 209 ValueError: 必須の環境変数(MODEL, API_KEY)が設定されていない場合 210 211 Example: 212 >>> # 環境変数から作成(LLM_MODEL, LLM_API_KEY, etc.) 213 >>> config = LLMProviderConfig.from_env() 214 215 >>> # Memory用LLM設定(MEMORY_LLM_MODEL, MEMORY_LLM_API_KEY, etc.) 216 >>> config = LLMProviderConfig.from_env(prefix="MEMORY_LLM_") 217 218 >>> # Embedder設定(EMBEDDER_MODEL, EMBEDDER_API_KEY, etc.) 219 >>> embedder = LLMProviderConfig.from_env(prefix="EMBEDDER_") 220 221 Environment Variables: 222 {prefix}MODEL: モデル名(必須)例: "azure/gpt-4", "openai/gpt-4" 223 {prefix}API_KEY: APIキー(必須) 224 {prefix}BASE_URL: ベースURL(Azure等で必要) 225 {prefix}API_VERSION: APIバージョン(Azure等で必要) 226 {prefix}AWS_ACCESS_KEY_ID: AWS Bedrock用 227 {prefix}AWS_SECRET_ACCESS_KEY: AWS Bedrock用 228 {prefix}AWS_REGION_NAME: AWS Bedrock用 229 """ 230 def get_env(key: str) -> Optional[str]: 231 return os.environ.get(f"{prefix}{key}") or None 232 233 model = get_env("MODEL") 234 api_key = get_env("API_KEY") 235 236 if not model: 237 raise ValueError(f"環境変数 {prefix}MODEL が設定されていません") 238 if not api_key: 239 raise ValueError(f"環境変数 {prefix}API_KEY が設定されていません") 240 241 return cls( 242 model=model, 243 api_key=api_key, 244 base_url=get_env("BASE_URL"), 245 api_version=get_env("API_VERSION"), 246 aws_access_key_id=get_env("AWS_ACCESS_KEY_ID"), 247 aws_secret_access_key=get_env("AWS_SECRET_ACCESS_KEY"), 248 aws_region_name=get_env("AWS_REGION_NAME"), 249 )
環境変数からLLMProviderConfigを作成
環境変数名は {prefix}{FIELD_NAME} の形式で読み込みます。
Args: prefix: 環境変数名のプレフィックス(デフォルト: "LLM_")
Returns: LLMProviderConfig instance
Raises: ValueError: 必須の環境変数(MODEL, API_KEY)が設定されていない場合
Example:
環境変数から作成(LLM_MODEL, LLM_API_KEY, etc.)
config = LLMProviderConfig.from_env()
>>> # Memory用LLM設定(MEMORY_LLM_MODEL, MEMORY_LLM_API_KEY, etc.) >>> config = LLMProviderConfig.from_env(prefix="MEMORY_LLM_") >>> # Embedder設定(EMBEDDER_MODEL, EMBEDDER_API_KEY, etc.) >>> embedder = LLMProviderConfig.from_env(prefix="EMBEDDER_")Environment Variables: {prefix}MODEL: モデル名(必須)例: "azure/gpt-4", "openai/gpt-4" {prefix}API_KEY: APIキー(必須) {prefix}BASE_URL: ベースURL(Azure等で必要) {prefix}API_VERSION: APIバージョン(Azure等で必要) {prefix}AWS_ACCESS_KEY_ID: AWS Bedrock用 {prefix}AWS_SECRET_ACCESS_KEY: AWS Bedrock用 {prefix}AWS_REGION_NAME: AWS Bedrock用
86def normalize_group_id(group_id: str) -> str: 87 """ 88 group_idをFalkorDB互換形式に正規化 89 90 FalkorDBの制約: 91 - 英数字、ハイフン、アンダースコアのみ使用可能 92 - 日本語等の特殊文字はハッシュ化 93 94 Args: 95 group_id: 元のgroup_id(日本語、スペース等を含む可能性あり) 96 97 Returns: 98 正規化されたgroup_id 99 100 Examples: 101 "org-ソリューションエンジニアリング本部" → "org-a1b2c3d4e5f6" 102 "role-プロダクトマネージャー" → "role-9g8h7i6j5k4l" 103 "user-12345" → "user-12345" (変更なし) 104 """ 105 # Already valid? (alphanumeric, dash, underscore only) 106 if re.match(r'^[a-zA-Z0-9_-]+$', group_id): 107 return group_id 108 109 # Extract prefix (org-/role-/user-/global-) 110 prefix_match = re.match(r'^(org-|role-|user-|global-)', group_id) 111 if prefix_match: 112 prefix = prefix_match.group(1) 113 suffix = group_id[len(prefix):] 114 else: 115 prefix = "" 116 suffix = group_id 117 118 # Hash the suffix if it contains non-ASCII or special chars 119 if not re.match(r'^[a-zA-Z0-9_-]+$', suffix): 120 # SHA256 hash (first 12 chars for readability) 121 hash_hex = hashlib.sha256(suffix.encode('utf-8')).hexdigest()[:12] 122 normalized = f"{prefix}{hash_hex}" 123 logger.debug(f"Normalized group_id: '{group_id}' → '{normalized}'") 124 return normalized 125 126 # Suffix is valid, just replace spaces with dashes 127 normalized = f"{prefix}{suffix}".replace(" ", "-").replace("_", "-") 128 return normalized
group_idをFalkorDB互換形式に正規化
FalkorDBの制約:
- 英数字、ハイフン、アンダースコアのみ使用可能
- 日本語等の特殊文字はハッシュ化
Args: group_id: 元のgroup_id(日本語、スペース等を含む可能性あり)
Returns: 正規化されたgroup_id
Examples: "org-ソリューションエンジニアリング本部" → "org-a1b2c3d4e5f6" "role-プロダクトマネージャー" → "role-9g8h7i6j5k4l" "user-12345" → "user-12345" (変更なし)
22class StorageProvider(str, Enum): 23 """ストレージプロバイダー種別""" 24 AZURE_BLOB = "azure_blob" 25 AWS_S3 = "aws_s3" 26 GCS = "gcs"
ストレージプロバイダー種別
97@dataclass 98class StorageConfig: 99 """共通ストレージ設定""" 100 provider: StorageProvider 101 bucket_name: str # Azure: container_name, S3: bucket, GCS: bucket 102 enabled: bool = True 103 # 共通オプション 104 max_file_size: int = 5 * 1024 * 1024 * 1024 # 5GB 105 auto_create_bucket: bool = False 106 prefix: str = "" # オブジェクト名プレフィックス 107 custom_domain: Optional[str] = None 108 # 接続設定 109 connection_timeout: int = 30 110 read_timeout: int = 300 # 大容量ファイル対応
共通ストレージ設定
113@dataclass 114class AzureBlobConfig: 115 """Azure Blob Storage設定 116 117 Example: 118 >>> # 辞書から作成 119 >>> config = AzureBlobConfig.from_dict({ 120 ... "bucket_name": "my-container", 121 ... "connection_string": "DefaultEndpointsProtocol=https;...", 122 ... "prefix": "uploads/" 123 ... }) 124 >>> 125 >>> # クライアント初期化 126 >>> client = AzureBlobStorageClient(config) 127 128 Note: 129 connection_stringはrepr=Falseでログ出力から除外されます。 130 """ 131 bucket_name: str # container_name 132 connection_string: str = field(repr=False) # シークレット: reprから除外 133 # 共通オプション 134 enabled: bool = True 135 max_file_size: int = 5 * 1024 * 1024 * 1024 # 5GB 136 auto_create_bucket: bool = False 137 prefix: str = "" 138 custom_domain: Optional[str] = None 139 connection_timeout: int = 30 140 read_timeout: int = 300 # 大容量ファイル対応 141 # Azure固有オプション 142 max_block_size: int = 8 * 1024 * 1024 # 8MB (5GBファイルで625ブロック) 143 max_single_put_size: int = 256 * 1024 * 1024 # 256MB 144 145 @property 146 def provider(self) -> StorageProvider: 147 return StorageProvider.AZURE_BLOB 148 149 def __str__(self) -> str: 150 return f"AzureBlobConfig(bucket_name={self.bucket_name!r}, connection_string=***)" 151 152 @classmethod 153 def from_dict(cls, data: Dict[str, Any]) -> "AzureBlobConfig": 154 """辞書からAzureBlobConfigを作成 155 156 Args: 157 data: 設定辞書 158 159 Returns: 160 AzureBlobConfig instance 161 162 Example: 163 >>> config = AzureBlobConfig.from_dict({ 164 ... "bucket_name": "my-container", 165 ... "connection_string": "DefaultEndpointsProtocol=https;..." 166 ... }) 167 """ 168 return cls( 169 bucket_name=data["bucket_name"], 170 connection_string=data["connection_string"], 171 enabled=data.get("enabled", True), 172 max_file_size=data.get("max_file_size", 5 * 1024 * 1024 * 1024), 173 auto_create_bucket=data.get("auto_create_bucket", False), 174 prefix=data.get("prefix", ""), 175 custom_domain=data.get("custom_domain"), 176 connection_timeout=data.get("connection_timeout", 30), 177 read_timeout=data.get("read_timeout", 300), 178 max_block_size=data.get("max_block_size", 8 * 1024 * 1024), 179 max_single_put_size=data.get("max_single_put_size", 256 * 1024 * 1024), 180 )
Azure Blob Storage設定
Example:
辞書から作成
config = AzureBlobConfig.from_dict({ ... "bucket_name": "my-container", ... "connection_string": "DefaultEndpointsProtocol=https;...", ... "prefix": "uploads/" ... })
クライアント初期化
client = AzureBlobStorageClient(config)
Note: connection_stringはrepr=Falseでログ出力から除外されます。
152 @classmethod 153 def from_dict(cls, data: Dict[str, Any]) -> "AzureBlobConfig": 154 """辞書からAzureBlobConfigを作成 155 156 Args: 157 data: 設定辞書 158 159 Returns: 160 AzureBlobConfig instance 161 162 Example: 163 >>> config = AzureBlobConfig.from_dict({ 164 ... "bucket_name": "my-container", 165 ... "connection_string": "DefaultEndpointsProtocol=https;..." 166 ... }) 167 """ 168 return cls( 169 bucket_name=data["bucket_name"], 170 connection_string=data["connection_string"], 171 enabled=data.get("enabled", True), 172 max_file_size=data.get("max_file_size", 5 * 1024 * 1024 * 1024), 173 auto_create_bucket=data.get("auto_create_bucket", False), 174 prefix=data.get("prefix", ""), 175 custom_domain=data.get("custom_domain"), 176 connection_timeout=data.get("connection_timeout", 30), 177 read_timeout=data.get("read_timeout", 300), 178 max_block_size=data.get("max_block_size", 8 * 1024 * 1024), 179 max_single_put_size=data.get("max_single_put_size", 256 * 1024 * 1024), 180 )
辞書からAzureBlobConfigを作成
Args: data: 設定辞書
Returns: AzureBlobConfig instance
Example:
config = AzureBlobConfig.from_dict({ ... "bucket_name": "my-container", ... "connection_string": "DefaultEndpointsProtocol=https;..." ... })
183@dataclass 184class S3Config: 185 """AWS S3設定 186 187 Example: 188 >>> # 辞書から作成 189 >>> config = S3Config.from_dict({ 190 ... "bucket_name": "my-bucket", 191 ... "aws_access_key_id": "AKIAXXXXXXX", 192 ... "aws_secret_access_key": "secretkey", 193 ... "region_name": "ap-northeast-1" 194 ... }) 195 >>> 196 >>> # クライアント初期化 197 >>> client = S3StorageClient(config) 198 199 Note: 200 aws_access_key_id, aws_secret_access_keyはrepr=Falseでログ出力から除外されます。 201 """ 202 bucket_name: str 203 aws_access_key_id: str = field(repr=False) # シークレット: reprから除外 204 aws_secret_access_key: str = field(repr=False) # シークレット: reprから除外 205 region_name: str = "us-east-1" 206 # 共通オプション 207 enabled: bool = True 208 max_file_size: int = 5 * 1024 * 1024 * 1024 # 5GB 209 auto_create_bucket: bool = False 210 prefix: str = "" 211 custom_domain: Optional[str] = None 212 connection_timeout: int = 30 213 read_timeout: int = 300 # 大容量ファイル対応 214 # S3固有オプション 215 endpoint_url: Optional[str] = None # MinIO等のS3互換ストレージ用 216 217 @property 218 def provider(self) -> StorageProvider: 219 return StorageProvider.AWS_S3 220 221 def __str__(self) -> str: 222 return f"S3Config(bucket_name={self.bucket_name!r}, region_name={self.region_name!r}, credentials=***)" 223 224 @classmethod 225 def from_dict(cls, data: Dict[str, Any]) -> "S3Config": 226 """辞書からS3Configを作成 227 228 Args: 229 data: 設定辞書 230 231 Returns: 232 S3Config instance 233 234 Example: 235 >>> config = S3Config.from_dict({ 236 ... "bucket_name": "my-bucket", 237 ... "aws_access_key_id": "AKIAXXXXXXX", 238 ... "aws_secret_access_key": "secretkey", 239 ... "region_name": "ap-northeast-1" 240 ... }) 241 """ 242 return cls( 243 bucket_name=data["bucket_name"], 244 aws_access_key_id=data["aws_access_key_id"], 245 aws_secret_access_key=data["aws_secret_access_key"], 246 region_name=data.get("region_name", "us-east-1"), 247 enabled=data.get("enabled", True), 248 max_file_size=data.get("max_file_size", 5 * 1024 * 1024 * 1024), 249 auto_create_bucket=data.get("auto_create_bucket", False), 250 prefix=data.get("prefix", ""), 251 custom_domain=data.get("custom_domain"), 252 connection_timeout=data.get("connection_timeout", 30), 253 read_timeout=data.get("read_timeout", 300), 254 endpoint_url=data.get("endpoint_url"), 255 )
AWS S3設定
Example:
辞書から作成
config = S3Config.from_dict({ ... "bucket_name": "my-bucket", ... "aws_access_key_id": "AKIAXXXXXXX", ... "aws_secret_access_key": "secretkey", ... "region_name": "ap-northeast-1" ... })
クライアント初期化
client = S3StorageClient(config)
Note: aws_access_key_id, aws_secret_access_keyはrepr=Falseでログ出力から除外されます。
224 @classmethod 225 def from_dict(cls, data: Dict[str, Any]) -> "S3Config": 226 """辞書からS3Configを作成 227 228 Args: 229 data: 設定辞書 230 231 Returns: 232 S3Config instance 233 234 Example: 235 >>> config = S3Config.from_dict({ 236 ... "bucket_name": "my-bucket", 237 ... "aws_access_key_id": "AKIAXXXXXXX", 238 ... "aws_secret_access_key": "secretkey", 239 ... "region_name": "ap-northeast-1" 240 ... }) 241 """ 242 return cls( 243 bucket_name=data["bucket_name"], 244 aws_access_key_id=data["aws_access_key_id"], 245 aws_secret_access_key=data["aws_secret_access_key"], 246 region_name=data.get("region_name", "us-east-1"), 247 enabled=data.get("enabled", True), 248 max_file_size=data.get("max_file_size", 5 * 1024 * 1024 * 1024), 249 auto_create_bucket=data.get("auto_create_bucket", False), 250 prefix=data.get("prefix", ""), 251 custom_domain=data.get("custom_domain"), 252 connection_timeout=data.get("connection_timeout", 30), 253 read_timeout=data.get("read_timeout", 300), 254 endpoint_url=data.get("endpoint_url"), 255 )
辞書からS3Configを作成
Args: data: 設定辞書
Returns: S3Config instance
Example:
config = S3Config.from_dict({ ... "bucket_name": "my-bucket", ... "aws_access_key_id": "AKIAXXXXXXX", ... "aws_secret_access_key": "secretkey", ... "region_name": "ap-northeast-1" ... })
258@dataclass 259class GCSConfig: 260 """Google Cloud Storage設定 261 262 Example: 263 >>> # 辞書から作成 264 >>> config = GCSConfig.from_dict({ 265 ... "bucket_name": "my-bucket", 266 ... "project_id": "my-project", 267 ... "credentials_path": "/path/to/service-account.json" 268 ... }) 269 >>> 270 >>> # クライアント初期化 271 >>> client = GCSStorageClient(config) 272 273 Note: 274 credentials_path, credentials_jsonはrepr=Falseでログ出力から除外されます。 275 """ 276 bucket_name: str 277 project_id: str 278 # 共通オプション 279 enabled: bool = True 280 max_file_size: int = 5 * 1024 * 1024 * 1024 # 5GB 281 auto_create_bucket: bool = False 282 prefix: str = "" 283 custom_domain: Optional[str] = None 284 connection_timeout: int = 30 285 read_timeout: int = 300 # 大容量ファイル対応 286 # GCS固有オプション(シークレット: reprから除外) 287 credentials_path: Optional[str] = field(default=None, repr=False) 288 credentials_json: Optional[str] = field(default=None, repr=False) 289 credentials_base64: Optional[str] = field(default=None, repr=False) # Base64-encoded JSON (TOML-safe) 290 291 @property 292 def provider(self) -> StorageProvider: 293 return StorageProvider.GCS 294 295 def __str__(self) -> str: 296 return f"GCSConfig(bucket_name={self.bucket_name!r}, project_id={self.project_id!r}, credentials=***)" 297 298 @classmethod 299 def from_dict(cls, data: Dict[str, Any]) -> "GCSConfig": 300 """辞書からGCSConfigを作成 301 302 Args: 303 data: 設定辞書 304 305 Returns: 306 GCSConfig instance 307 308 Example: 309 >>> config = GCSConfig.from_dict({ 310 ... "bucket_name": "my-bucket", 311 ... "project_id": "my-project", 312 ... "credentials_path": "/path/to/service-account.json" 313 ... }) 314 """ 315 return cls( 316 bucket_name=data["bucket_name"], 317 project_id=data["project_id"], 318 enabled=data.get("enabled", True), 319 max_file_size=data.get("max_file_size", 5 * 1024 * 1024 * 1024), 320 auto_create_bucket=data.get("auto_create_bucket", False), 321 prefix=data.get("prefix", ""), 322 custom_domain=data.get("custom_domain"), 323 connection_timeout=data.get("connection_timeout", 30), 324 read_timeout=data.get("read_timeout", 300), 325 credentials_path=data.get("credentials_path"), 326 credentials_json=data.get("credentials_json"), 327 credentials_base64=data.get("credentials_base64"), 328 )
Google Cloud Storage設定
Example:
辞書から作成
config = GCSConfig.from_dict({ ... "bucket_name": "my-bucket", ... "project_id": "my-project", ... "credentials_path": "/path/to/service-account.json" ... })
クライアント初期化
client = GCSStorageClient(config)
Note: credentials_path, credentials_jsonはrepr=Falseでログ出力から除外されます。
298 @classmethod 299 def from_dict(cls, data: Dict[str, Any]) -> "GCSConfig": 300 """辞書からGCSConfigを作成 301 302 Args: 303 data: 設定辞書 304 305 Returns: 306 GCSConfig instance 307 308 Example: 309 >>> config = GCSConfig.from_dict({ 310 ... "bucket_name": "my-bucket", 311 ... "project_id": "my-project", 312 ... "credentials_path": "/path/to/service-account.json" 313 ... }) 314 """ 315 return cls( 316 bucket_name=data["bucket_name"], 317 project_id=data["project_id"], 318 enabled=data.get("enabled", True), 319 max_file_size=data.get("max_file_size", 5 * 1024 * 1024 * 1024), 320 auto_create_bucket=data.get("auto_create_bucket", False), 321 prefix=data.get("prefix", ""), 322 custom_domain=data.get("custom_domain"), 323 connection_timeout=data.get("connection_timeout", 30), 324 read_timeout=data.get("read_timeout", 300), 325 credentials_path=data.get("credentials_path"), 326 credentials_json=data.get("credentials_json"), 327 credentials_base64=data.get("credentials_base64"), 328 )
辞書からGCSConfigを作成
Args: data: 設定辞書
Returns: GCSConfig instance
Example:
config = GCSConfig.from_dict({ ... "bucket_name": "my-bucket", ... "project_id": "my-project", ... "credentials_path": "/path/to/service-account.json" ... })
39class AzureBlobStorageClient(StorageClientBase): 40 """ 41 Azure Blob Storageクライアント 42 43 Features: 44 - ファイルアップロード・ダウンロード 45 - オブジェクト一覧・削除 46 - カスタムドメイン対応 47 - 日本語ファイル名対応(URLエンコード) 48 - メタデータサニタイズ(Base64エンコード) 49 50 Example: 51 config = AzureBlobConfig( 52 bucket_name="my-container", 53 connection_string="DefaultEndpointsProtocol=https;...", 54 ) 55 client = AzureBlobStorageClient(config) 56 result = await client.upload_file("/path/to/file.txt") 57 """ 58 59 def __init__(self, config: AzureBlobConfig): 60 """ 61 Initialize Azure Blob Storage Client 62 63 Args: 64 config: AzureBlobConfig instance 65 66 Raises: 67 StorageConfigError: 設定が不正な場合 68 StorageConnectionError: 接続に失敗した場合 69 """ 70 super().__init__(config) 71 72 if not AZURE_BLOB_AVAILABLE: 73 raise StorageConfigError( 74 "azure-storage-blob package is not installed. " 75 "Install with: pip install azure-storage-blob" 76 ) 77 78 if not config.connection_string: 79 raise StorageConfigError( 80 "Azure Blob connection_string is required" 81 ) 82 83 if not config.bucket_name: 84 raise StorageConfigError( 85 "Azure Blob bucket_name (container_name) is required" 86 ) 87 88 self._azure_config: AzureBlobConfig = config 89 self._blob_service_client: Optional[BlobServiceClient] = None 90 91 try: 92 self._blob_service_client = BlobServiceClient.from_connection_string( 93 conn_str=self._get_api_key(config.connection_string), 94 max_block_size=config.max_block_size, 95 max_single_put_size=config.max_single_put_size, 96 connection_timeout=config.connection_timeout, 97 read_timeout=config.read_timeout, 98 ) 99 self._enabled = config.enabled 100 logger.info( 101 f"AzureBlobStorageClient initialized - " 102 f"container={config.bucket_name}" 103 ) 104 except Exception as e: 105 logger.error(f"Failed to initialize Azure Blob client: {e}") 106 raise StorageConnectionError( 107 f"Failed to connect to Azure Blob Storage: {e}" 108 ) from e 109 110 @property 111 def provider(self) -> StorageProvider: 112 return StorageProvider.AZURE_BLOB 113 114 @property 115 def container_name(self) -> str: 116 """コンテナ名(bucket_nameのエイリアス)""" 117 return self._azure_config.bucket_name 118 119 def _get_container_client(self): 120 """コンテナクライアントを取得""" 121 return self._blob_service_client.get_container_client(self.container_name) 122 123 def _get_blob_client(self, blob_name: str): 124 """Blobクライアントを取得""" 125 return self._blob_service_client.get_blob_client( 126 container=self.container_name, 127 blob=blob_name, 128 ) 129 130 def _encode_blob_name(self, blob_name: str) -> str: 131 """Blob名をURLエンコード(日本語対応)""" 132 try: 133 if blob_name != unquote(blob_name): 134 return blob_name 135 return quote(blob_name, safe='/') 136 except Exception: 137 return quote(blob_name, safe='/') 138 139 def _generate_public_url(self, blob_url: str) -> str: 140 """公開URLを生成(カスタムドメイン対応)""" 141 custom_domain = self._azure_config.custom_domain 142 if not custom_domain: 143 return blob_url 144 145 try: 146 parsed = urlparse(blob_url) 147 return f"https://{custom_domain}{parsed.path}" 148 except Exception as e: 149 logger.warning(f"Failed to generate custom domain URL: {e}") 150 return blob_url 151 152 async def ensure_bucket_exists(self) -> bool: 153 """コンテナの存在確認・作成""" 154 if not self._enabled: 155 return False 156 157 try: 158 container_client = self._get_container_client() 159 160 exists = await asyncio.get_event_loop().run_in_executor( 161 None, container_client.exists 162 ) 163 164 if not exists: 165 if self._azure_config.auto_create_bucket: 166 await asyncio.get_event_loop().run_in_executor( 167 None, container_client.create_container 168 ) 169 logger.info(f"Created container: {self.container_name}") 170 else: 171 logger.error( 172 f"Container {self.container_name} does not exist " 173 "and auto_create_bucket is disabled" 174 ) 175 return False 176 177 return True 178 179 except Exception as e: 180 logger.error(f"Failed to ensure container exists: {e}") 181 return False 182 183 async def upload_file( 184 self, 185 file_path: str, 186 object_name: Optional[str] = None, 187 prefix: Optional[str] = None, 188 metadata: Optional[Dict[str, Any]] = None, 189 ) -> UploadResult: 190 """ 191 ファイルをAzure Blobにアップロード 192 193 Args: 194 file_path: アップロードするファイルのローカルパス 195 object_name: Blob名(省略時はファイル名が使用される) 196 prefix: オブジェクト名プレフィックス(フォルダ構造を作成) 197 metadata: 追加メタデータ(日本語はBase64エンコードされる) 198 199 Returns: 200 UploadResult: アップロード結果 201 - success: 成功/失敗 202 - object_name: アップロード先のBlob名 203 - object_url: 公開URL 204 - file_size: ファイルサイズ(bytes) 205 - content_type: MIMEタイプ 206 207 Example: 208 >>> # 基本的なアップロード 209 >>> result = await client.upload_file("/tmp/report.pdf") 210 >>> if result.success: 211 ... print(f"Uploaded: {result.object_url}") 212 >>> 213 >>> # プレフィックス指定でフォルダ構造を作成 214 >>> result = await client.upload_file( 215 ... file_path="/tmp/image.png", 216 ... prefix="uploads/2024/01", 217 ... metadata={"author": "user-123"} 218 ... ) 219 """ 220 if not self._enabled: 221 return UploadResult( 222 success=False, 223 error="Azure Blob Storage is not enabled", 224 error_code="STORAGE_NOT_ENABLED", 225 ) 226 227 try: 228 # ファイルサイズ検証 229 self._validate_file_size(file_path) 230 231 file_path_obj = Path(file_path) 232 233 # オブジェクト名決定 234 if not object_name: 235 object_name = file_path_obj.name 236 237 # URLエンコード 238 safe_object_name = self._encode_blob_name(object_name) 239 240 # プレフィックス適用(呼び出し側が指定したprefixのみ使用) 241 if prefix: 242 full_object_name = f"{prefix}/{safe_object_name}" 243 else: 244 full_object_name = safe_object_name 245 246 # コンテナ確認 247 if not await self.ensure_bucket_exists(): 248 return UploadResult( 249 success=False, 250 error="Container is not available", 251 error_code="CONTAINER_NOT_AVAILABLE", 252 ) 253 254 # Content-Type設定 255 content_type = self._get_content_type(file_path) 256 content_settings = ContentSettings(content_type=content_type) 257 258 # メタデータサニタイズ 259 safe_metadata = self._sanitize_metadata(metadata) 260 261 # アップロード実行 262 blob_client = self._get_blob_client(full_object_name) 263 file_size = file_path_obj.stat().st_size 264 265 with open(file_path_obj, 'rb') as data: 266 await asyncio.get_event_loop().run_in_executor( 267 None, 268 lambda: blob_client.upload_blob( 269 data, 270 overwrite=True, 271 content_settings=content_settings, 272 metadata=safe_metadata if safe_metadata else None, 273 ) 274 ) 275 276 # URL生成 277 blob_url = self._generate_public_url(blob_client.url) 278 279 logger.info("Successfully uploaded file to Azure Blob") 280 281 return UploadResult( 282 success=True, 283 object_name=full_object_name, 284 object_url=blob_url, 285 file_size=file_size, 286 content_type=content_type, 287 metadata={ 288 "uploaded_at": time.time(), 289 "prefix": prefix, 290 **(safe_metadata or {}), 291 }, 292 ) 293 294 except StorageOperationError: 295 raise 296 except Exception as e: 297 logger.error(f"Failed to upload file: {e}") 298 return UploadResult( 299 success=False, 300 error=f"Upload failed: {e}", 301 error_code="UPLOAD_FAILED", 302 ) 303 304 async def download_file( 305 self, 306 object_name: str, 307 download_path: str, 308 ) -> DownloadResult: 309 """ 310 ファイルをダウンロード(ストリーミング) 311 312 大きなファイルでもメモリを圧迫しないよう、 313 チャンク単位でストリーミングダウンロードします。 314 315 Args: 316 object_name: ダウンロードするBlob名 317 download_path: ダウンロード先のローカルパス 318 319 Returns: 320 DownloadResult: ダウンロード結果 321 - success: 成功/失敗 322 - local_path: ダウンロード先パス 323 - file_size: ファイルサイズ(bytes) 324 - error: エラーメッセージ(失敗時) 325 326 Example: 327 >>> # Blobをダウンロード 328 >>> result = await client.download_file( 329 ... object_name="uploads/2024/01/report.pdf", 330 ... download_path="/tmp/downloaded_report.pdf" 331 ... ) 332 >>> if result.success: 333 ... print(f"Downloaded to: {result.local_path}") 334 ... print(f"Size: {result.file_size} bytes") 335 >>> else: 336 ... print(f"Error: {result.error}") 337 338 Note: 339 ダウンロード先ディレクトリが存在しない場合は自動作成されます。 340 """ 341 if not self._enabled: 342 return DownloadResult( 343 success=False, 344 error="Azure Blob Storage is not enabled", 345 error_code="STORAGE_NOT_ENABLED", 346 ) 347 348 try: 349 blob_client = self._get_blob_client(object_name) 350 351 # 存在確認 352 exists = await asyncio.get_event_loop().run_in_executor( 353 None, blob_client.exists 354 ) 355 if not exists: 356 return DownloadResult( 357 success=False, 358 object_name=object_name, 359 error=f"Blob not found: {object_name}", 360 error_code="BLOB_NOT_FOUND", 361 ) 362 363 # ダウンロード先ディレクトリ作成 364 download_path_obj = Path(download_path) 365 download_path_obj.parent.mkdir(parents=True, exist_ok=True) 366 367 # ストリーミングダウンロード(メモリ効率化) 368 def _streaming_download(): 369 with open(download_path_obj, 'wb') as f: 370 download_stream = blob_client.download_blob() 371 # chunks()でチャンク単位でダウンロード 372 for chunk in download_stream.chunks(): 373 f.write(chunk) 374 375 await asyncio.get_event_loop().run_in_executor( 376 None, _streaming_download 377 ) 378 379 file_size = download_path_obj.stat().st_size 380 logger.info("Successfully downloaded blob") 381 382 return DownloadResult( 383 success=True, 384 object_name=object_name, 385 local_path=str(download_path_obj), 386 file_size=file_size, 387 ) 388 389 except Exception as e: 390 logger.error(f"Failed to download blob: {e}") 391 return DownloadResult( 392 success=False, 393 object_name=object_name, 394 error=f"Download failed: {e}", 395 error_code="DOWNLOAD_FAILED", 396 ) 397 398 async def list_objects( 399 self, 400 prefix: str = "", 401 max_results: Optional[int] = None, 402 ) -> ListResult: 403 """ 404 オブジェクト一覧を取得 405 406 指定したプレフィックスに一致するBlobの一覧を取得します。 407 408 Args: 409 prefix: プレフィックスフィルタ(フォルダパスとして機能) 410 max_results: 最大取得数(省略時は全件取得) 411 412 Returns: 413 ListResult: 一覧結果 414 - success: 成功/失敗 415 - objects: ObjectInfoのリスト 416 - count: 取得件数 417 418 Example: 419 >>> # 全オブジェクト一覧 420 >>> result = await client.list_objects() 421 >>> print(f"Total: {result.count} objects") 422 >>> 423 >>> # プレフィックスでフィルタ 424 >>> result = await client.list_objects( 425 ... prefix="uploads/2024/", 426 ... max_results=100 427 ... ) 428 >>> for obj in result.objects: 429 ... print(f"{obj.name}: {obj.size} bytes") 430 """ 431 if not self._enabled: 432 return ListResult( 433 success=False, 434 error="Azure Blob Storage is not enabled", 435 error_code="STORAGE_NOT_ENABLED", 436 ) 437 438 try: 439 container_client = self._get_container_client() 440 441 blob_list = await asyncio.get_event_loop().run_in_executor( 442 None, 443 lambda: list(container_client.list_blobs(name_starts_with=prefix)) 444 ) 445 446 objects = [] 447 for blob in blob_list: 448 if max_results and len(objects) >= max_results: 449 break 450 objects.append(ObjectInfo( 451 name=blob.name, 452 size=blob.size, 453 last_modified=blob.last_modified.isoformat() if blob.last_modified else None, 454 content_type=blob.content_settings.content_type if blob.content_settings else None, 455 )) 456 457 logger.info(f"Listed {len(objects)} blobs") 458 459 return ListResult( 460 success=True, 461 objects=objects, 462 count=len(objects), 463 prefix=prefix, 464 ) 465 466 except Exception as e: 467 logger.error(f"Failed to list blobs: {e}") 468 return ListResult( 469 success=False, 470 error=f"List failed: {e}", 471 error_code="LIST_FAILED", 472 ) 473 474 async def delete_object(self, object_name: str) -> bool: 475 """ 476 オブジェクトを削除 477 478 Args: 479 object_name: 削除するBlob名 480 481 Returns: 482 成功した場合True 483 """ 484 if not self._enabled: 485 return False 486 487 try: 488 blob_client = self._get_blob_client(object_name) 489 490 await asyncio.get_event_loop().run_in_executor( 491 None, 492 lambda: blob_client.delete_blob(delete_snapshots="include") 493 ) 494 495 logger.info("Successfully deleted blob") 496 return True 497 498 except Exception as e: 499 logger.error(f"Failed to delete blob: {e}") 500 return False 501 502 async def object_exists(self, object_name: str) -> bool: 503 """ 504 オブジェクトの存在確認 505 506 Args: 507 object_name: 確認するBlob名 508 509 Returns: 510 存在する場合True 511 """ 512 if not self._enabled: 513 return False 514 515 try: 516 blob_client = self._get_blob_client(object_name) 517 return await asyncio.get_event_loop().run_in_executor( 518 None, blob_client.exists 519 ) 520 except Exception as e: 521 logger.error(f"Failed to check blob existence: {e}") 522 return False 523 524 async def get_object_url( 525 self, 526 object_name: str, 527 expires_in: Optional[int] = None, 528 ) -> str: 529 """ 530 オブジェクトのURLを取得 531 532 Args: 533 object_name: Blob名 534 expires_in: SAS有効期限(秒)- 未実装、将来拡張用 535 536 Returns: 537 オブジェクトURL 538 """ 539 if not self._enabled: 540 return "" 541 542 blob_client = self._get_blob_client(object_name) 543 return self._generate_public_url(blob_client.url) 544 545 def _sanitize_path(self, path: str) -> str: 546 """パスをサニタイズ(パストラバーサル防止) 547 548 Args: 549 path: サニタイズするパス 550 551 Returns: 552 サニタイズされたパス(危険な文字を除去) 553 """ 554 import re 555 # パストラバーサル攻撃防止: ../ や ..\\ を除去 556 # 先頭の / も除去 557 sanitized = re.sub(r'\.\.[\\/]', '', path) 558 sanitized = re.sub(r'^[\\/]+', '', sanitized) 559 # NULL文字や制御文字を除去 560 sanitized = re.sub(r'[\x00-\x1f]', '', sanitized) 561 return sanitized 562 563 async def download_objects_by_prefix( 564 self, 565 prefix: str, 566 download_dir: str, 567 ) -> Dict[str, Any]: 568 """ 569 プレフィックスに一致するオブジェクトをすべてダウンロード 570 571 Args: 572 prefix: プレフィックス 573 download_dir: ダウンロード先ディレクトリ 574 575 Returns: 576 ダウンロード結果サマリー 577 578 Security: 579 パストラバーサル攻撃を防ぐため、オブジェクト名をサニタイズし、 580 ダウンロード先がdownload_dir配下であることを検証します。 581 """ 582 list_result = await self.list_objects(prefix) 583 if not list_result.success: 584 return { 585 "success": False, 586 "error": list_result.error, 587 "downloaded": [], 588 "failed": [], 589 } 590 591 download_dir_obj = Path(download_dir).resolve() 592 downloaded = [] 593 failed = [] 594 skipped = [] 595 596 for obj in list_result.objects: 597 # プレフィックス除去 598 relative_path = obj.name 599 if obj.name.startswith(prefix): 600 relative_path = obj.name[len(prefix):] 601 602 # パスサニタイズ(パストラバーサル防止) 603 relative_path = self._sanitize_path(relative_path) 604 605 if not relative_path: 606 skipped.append({ 607 "object_name": obj.name, 608 "reason": "Invalid path after sanitization", 609 }) 610 continue 611 612 download_path = (download_dir_obj / relative_path).resolve() 613 614 # パストラバーサル検証: download_dir配下であることを確認 615 try: 616 download_path.relative_to(download_dir_obj) 617 except ValueError: 618 logger.warning( 619 f"Path traversal attempt detected, skipping object" 620 ) 621 skipped.append({ 622 "object_name": obj.name, 623 "reason": "Path traversal detected", 624 }) 625 continue 626 627 result = await self.download_file(obj.name, str(download_path)) 628 if result.success: 629 downloaded.append({ 630 "object_name": obj.name, 631 "local_path": result.local_path, 632 "file_size": result.file_size, 633 }) 634 else: 635 failed.append({ 636 "object_name": obj.name, 637 "error": result.error, 638 }) 639 640 return { 641 "success": True, 642 "downloaded": downloaded, 643 "failed": failed, 644 "skipped": skipped, 645 "success_count": len(downloaded), 646 "failed_count": len(failed), 647 "skipped_count": len(skipped), 648 } 649 650 async def close(self) -> None: 651 """クライアントリソースをクローズ 652 653 Note: 654 例外が発生してもリソースを確実にクリーンアップします。 655 """ 656 try: 657 if self._blob_service_client: 658 # BlobServiceClientはclose()メソッドを持たないが、 659 # 将来のバージョンで追加される可能性があるため確認 660 if hasattr(self._blob_service_client, 'close'): 661 self._blob_service_client.close() 662 except Exception as e: 663 logger.warning(f"Error closing Azure Blob client: {e}") 664 finally: 665 self._blob_service_client = None 666 self._enabled = False 667 logger.debug("AzureBlobStorageClient closed")
Azure Blob Storageクライアント
Features:
- ファイルアップロード・ダウンロード
- オブジェクト一覧・削除
- カスタムドメイン対応
- 日本語ファイル名対応(URLエンコード)
- メタデータサニタイズ(Base64エンコード)
Example: config = AzureBlobConfig( bucket_name="my-container", connection_string="DefaultEndpointsProtocol=https;...", ) client = AzureBlobStorageClient(config) result = await client.upload_file("/path/to/file.txt")
59 def __init__(self, config: AzureBlobConfig): 60 """ 61 Initialize Azure Blob Storage Client 62 63 Args: 64 config: AzureBlobConfig instance 65 66 Raises: 67 StorageConfigError: 設定が不正な場合 68 StorageConnectionError: 接続に失敗した場合 69 """ 70 super().__init__(config) 71 72 if not AZURE_BLOB_AVAILABLE: 73 raise StorageConfigError( 74 "azure-storage-blob package is not installed. " 75 "Install with: pip install azure-storage-blob" 76 ) 77 78 if not config.connection_string: 79 raise StorageConfigError( 80 "Azure Blob connection_string is required" 81 ) 82 83 if not config.bucket_name: 84 raise StorageConfigError( 85 "Azure Blob bucket_name (container_name) is required" 86 ) 87 88 self._azure_config: AzureBlobConfig = config 89 self._blob_service_client: Optional[BlobServiceClient] = None 90 91 try: 92 self._blob_service_client = BlobServiceClient.from_connection_string( 93 conn_str=self._get_api_key(config.connection_string), 94 max_block_size=config.max_block_size, 95 max_single_put_size=config.max_single_put_size, 96 connection_timeout=config.connection_timeout, 97 read_timeout=config.read_timeout, 98 ) 99 self._enabled = config.enabled 100 logger.info( 101 f"AzureBlobStorageClient initialized - " 102 f"container={config.bucket_name}" 103 ) 104 except Exception as e: 105 logger.error(f"Failed to initialize Azure Blob client: {e}") 106 raise StorageConnectionError( 107 f"Failed to connect to Azure Blob Storage: {e}" 108 ) from e
Initialize Azure Blob Storage Client
Args: config: AzureBlobConfig instance
Raises: StorageConfigError: 設定が不正な場合 StorageConnectionError: 接続に失敗した場合
114 @property 115 def container_name(self) -> str: 116 """コンテナ名(bucket_nameのエイリアス)""" 117 return self._azure_config.bucket_name
コンテナ名(bucket_nameのエイリアス)
152 async def ensure_bucket_exists(self) -> bool: 153 """コンテナの存在確認・作成""" 154 if not self._enabled: 155 return False 156 157 try: 158 container_client = self._get_container_client() 159 160 exists = await asyncio.get_event_loop().run_in_executor( 161 None, container_client.exists 162 ) 163 164 if not exists: 165 if self._azure_config.auto_create_bucket: 166 await asyncio.get_event_loop().run_in_executor( 167 None, container_client.create_container 168 ) 169 logger.info(f"Created container: {self.container_name}") 170 else: 171 logger.error( 172 f"Container {self.container_name} does not exist " 173 "and auto_create_bucket is disabled" 174 ) 175 return False 176 177 return True 178 179 except Exception as e: 180 logger.error(f"Failed to ensure container exists: {e}") 181 return False
コンテナの存在確認・作成
183 async def upload_file( 184 self, 185 file_path: str, 186 object_name: Optional[str] = None, 187 prefix: Optional[str] = None, 188 metadata: Optional[Dict[str, Any]] = None, 189 ) -> UploadResult: 190 """ 191 ファイルをAzure Blobにアップロード 192 193 Args: 194 file_path: アップロードするファイルのローカルパス 195 object_name: Blob名(省略時はファイル名が使用される) 196 prefix: オブジェクト名プレフィックス(フォルダ構造を作成) 197 metadata: 追加メタデータ(日本語はBase64エンコードされる) 198 199 Returns: 200 UploadResult: アップロード結果 201 - success: 成功/失敗 202 - object_name: アップロード先のBlob名 203 - object_url: 公開URL 204 - file_size: ファイルサイズ(bytes) 205 - content_type: MIMEタイプ 206 207 Example: 208 >>> # 基本的なアップロード 209 >>> result = await client.upload_file("/tmp/report.pdf") 210 >>> if result.success: 211 ... print(f"Uploaded: {result.object_url}") 212 >>> 213 >>> # プレフィックス指定でフォルダ構造を作成 214 >>> result = await client.upload_file( 215 ... file_path="/tmp/image.png", 216 ... prefix="uploads/2024/01", 217 ... metadata={"author": "user-123"} 218 ... ) 219 """ 220 if not self._enabled: 221 return UploadResult( 222 success=False, 223 error="Azure Blob Storage is not enabled", 224 error_code="STORAGE_NOT_ENABLED", 225 ) 226 227 try: 228 # ファイルサイズ検証 229 self._validate_file_size(file_path) 230 231 file_path_obj = Path(file_path) 232 233 # オブジェクト名決定 234 if not object_name: 235 object_name = file_path_obj.name 236 237 # URLエンコード 238 safe_object_name = self._encode_blob_name(object_name) 239 240 # プレフィックス適用(呼び出し側が指定したprefixのみ使用) 241 if prefix: 242 full_object_name = f"{prefix}/{safe_object_name}" 243 else: 244 full_object_name = safe_object_name 245 246 # コンテナ確認 247 if not await self.ensure_bucket_exists(): 248 return UploadResult( 249 success=False, 250 error="Container is not available", 251 error_code="CONTAINER_NOT_AVAILABLE", 252 ) 253 254 # Content-Type設定 255 content_type = self._get_content_type(file_path) 256 content_settings = ContentSettings(content_type=content_type) 257 258 # メタデータサニタイズ 259 safe_metadata = self._sanitize_metadata(metadata) 260 261 # アップロード実行 262 blob_client = self._get_blob_client(full_object_name) 263 file_size = file_path_obj.stat().st_size 264 265 with open(file_path_obj, 'rb') as data: 266 await asyncio.get_event_loop().run_in_executor( 267 None, 268 lambda: blob_client.upload_blob( 269 data, 270 overwrite=True, 271 content_settings=content_settings, 272 metadata=safe_metadata if safe_metadata else None, 273 ) 274 ) 275 276 # URL生成 277 blob_url = self._generate_public_url(blob_client.url) 278 279 logger.info("Successfully uploaded file to Azure Blob") 280 281 return UploadResult( 282 success=True, 283 object_name=full_object_name, 284 object_url=blob_url, 285 file_size=file_size, 286 content_type=content_type, 287 metadata={ 288 "uploaded_at": time.time(), 289 "prefix": prefix, 290 **(safe_metadata or {}), 291 }, 292 ) 293 294 except StorageOperationError: 295 raise 296 except Exception as e: 297 logger.error(f"Failed to upload file: {e}") 298 return UploadResult( 299 success=False, 300 error=f"Upload failed: {e}", 301 error_code="UPLOAD_FAILED", 302 )
ファイルをAzure Blobにアップロード
Args: file_path: アップロードするファイルのローカルパス object_name: Blob名(省略時はファイル名が使用される) prefix: オブジェクト名プレフィックス(フォルダ構造を作成) metadata: 追加メタデータ(日本語はBase64エンコードされる)
Returns: UploadResult: アップロード結果 - success: 成功/失敗 - object_name: アップロード先のBlob名 - object_url: 公開URL - file_size: ファイルサイズ(bytes) - content_type: MIMEタイプ
Example:
基本的なアップロード
result = await client.upload_file("/tmp/report.pdf") if result.success: ... print(f"Uploaded: {result.object_url}")
プレフィックス指定でフォルダ構造を作成
result = await client.upload_file( ... file_path="/tmp/image.png", ... prefix="uploads/2024/01", ... metadata={"author": "user-123"} ... )
304 async def download_file( 305 self, 306 object_name: str, 307 download_path: str, 308 ) -> DownloadResult: 309 """ 310 ファイルをダウンロード(ストリーミング) 311 312 大きなファイルでもメモリを圧迫しないよう、 313 チャンク単位でストリーミングダウンロードします。 314 315 Args: 316 object_name: ダウンロードするBlob名 317 download_path: ダウンロード先のローカルパス 318 319 Returns: 320 DownloadResult: ダウンロード結果 321 - success: 成功/失敗 322 - local_path: ダウンロード先パス 323 - file_size: ファイルサイズ(bytes) 324 - error: エラーメッセージ(失敗時) 325 326 Example: 327 >>> # Blobをダウンロード 328 >>> result = await client.download_file( 329 ... object_name="uploads/2024/01/report.pdf", 330 ... download_path="/tmp/downloaded_report.pdf" 331 ... ) 332 >>> if result.success: 333 ... print(f"Downloaded to: {result.local_path}") 334 ... print(f"Size: {result.file_size} bytes") 335 >>> else: 336 ... print(f"Error: {result.error}") 337 338 Note: 339 ダウンロード先ディレクトリが存在しない場合は自動作成されます。 340 """ 341 if not self._enabled: 342 return DownloadResult( 343 success=False, 344 error="Azure Blob Storage is not enabled", 345 error_code="STORAGE_NOT_ENABLED", 346 ) 347 348 try: 349 blob_client = self._get_blob_client(object_name) 350 351 # 存在確認 352 exists = await asyncio.get_event_loop().run_in_executor( 353 None, blob_client.exists 354 ) 355 if not exists: 356 return DownloadResult( 357 success=False, 358 object_name=object_name, 359 error=f"Blob not found: {object_name}", 360 error_code="BLOB_NOT_FOUND", 361 ) 362 363 # ダウンロード先ディレクトリ作成 364 download_path_obj = Path(download_path) 365 download_path_obj.parent.mkdir(parents=True, exist_ok=True) 366 367 # ストリーミングダウンロード(メモリ効率化) 368 def _streaming_download(): 369 with open(download_path_obj, 'wb') as f: 370 download_stream = blob_client.download_blob() 371 # chunks()でチャンク単位でダウンロード 372 for chunk in download_stream.chunks(): 373 f.write(chunk) 374 375 await asyncio.get_event_loop().run_in_executor( 376 None, _streaming_download 377 ) 378 379 file_size = download_path_obj.stat().st_size 380 logger.info("Successfully downloaded blob") 381 382 return DownloadResult( 383 success=True, 384 object_name=object_name, 385 local_path=str(download_path_obj), 386 file_size=file_size, 387 ) 388 389 except Exception as e: 390 logger.error(f"Failed to download blob: {e}") 391 return DownloadResult( 392 success=False, 393 object_name=object_name, 394 error=f"Download failed: {e}", 395 error_code="DOWNLOAD_FAILED", 396 )
ファイルをダウンロード(ストリーミング)
大きなファイルでもメモリを圧迫しないよう、 チャンク単位でストリーミングダウンロードします。
Args: object_name: ダウンロードするBlob名 download_path: ダウンロード先のローカルパス
Returns: DownloadResult: ダウンロード結果 - success: 成功/失敗 - local_path: ダウンロード先パス - file_size: ファイルサイズ(bytes) - error: エラーメッセージ(失敗時)
Example:
Blobをダウンロード
result = await client.download_file( ... object_name="uploads/2024/01/report.pdf", ... download_path="/tmp/downloaded_report.pdf" ... ) if result.success: ... print(f"Downloaded to: {result.local_path}") ... print(f"Size: {result.file_size} bytes") else: ... print(f"Error: {result.error}")
Note: ダウンロード先ディレクトリが存在しない場合は自動作成されます。
398 async def list_objects( 399 self, 400 prefix: str = "", 401 max_results: Optional[int] = None, 402 ) -> ListResult: 403 """ 404 オブジェクト一覧を取得 405 406 指定したプレフィックスに一致するBlobの一覧を取得します。 407 408 Args: 409 prefix: プレフィックスフィルタ(フォルダパスとして機能) 410 max_results: 最大取得数(省略時は全件取得) 411 412 Returns: 413 ListResult: 一覧結果 414 - success: 成功/失敗 415 - objects: ObjectInfoのリスト 416 - count: 取得件数 417 418 Example: 419 >>> # 全オブジェクト一覧 420 >>> result = await client.list_objects() 421 >>> print(f"Total: {result.count} objects") 422 >>> 423 >>> # プレフィックスでフィルタ 424 >>> result = await client.list_objects( 425 ... prefix="uploads/2024/", 426 ... max_results=100 427 ... ) 428 >>> for obj in result.objects: 429 ... print(f"{obj.name}: {obj.size} bytes") 430 """ 431 if not self._enabled: 432 return ListResult( 433 success=False, 434 error="Azure Blob Storage is not enabled", 435 error_code="STORAGE_NOT_ENABLED", 436 ) 437 438 try: 439 container_client = self._get_container_client() 440 441 blob_list = await asyncio.get_event_loop().run_in_executor( 442 None, 443 lambda: list(container_client.list_blobs(name_starts_with=prefix)) 444 ) 445 446 objects = [] 447 for blob in blob_list: 448 if max_results and len(objects) >= max_results: 449 break 450 objects.append(ObjectInfo( 451 name=blob.name, 452 size=blob.size, 453 last_modified=blob.last_modified.isoformat() if blob.last_modified else None, 454 content_type=blob.content_settings.content_type if blob.content_settings else None, 455 )) 456 457 logger.info(f"Listed {len(objects)} blobs") 458 459 return ListResult( 460 success=True, 461 objects=objects, 462 count=len(objects), 463 prefix=prefix, 464 ) 465 466 except Exception as e: 467 logger.error(f"Failed to list blobs: {e}") 468 return ListResult( 469 success=False, 470 error=f"List failed: {e}", 471 error_code="LIST_FAILED", 472 )
オブジェクト一覧を取得
指定したプレフィックスに一致するBlobの一覧を取得します。
Args: prefix: プレフィックスフィルタ(フォルダパスとして機能) max_results: 最大取得数(省略時は全件取得)
Returns: ListResult: 一覧結果 - success: 成功/失敗 - objects: ObjectInfoのリスト - count: 取得件数
Example:
全オブジェクト一覧
result = await client.list_objects() print(f"Total: {result.count} objects")
プレフィックスでフィルタ
result = await client.list_objects( ... prefix="uploads/2024/", ... max_results=100 ... ) for obj in result.objects: ... print(f"{obj.name}: {obj.size} bytes")
474 async def delete_object(self, object_name: str) -> bool: 475 """ 476 オブジェクトを削除 477 478 Args: 479 object_name: 削除するBlob名 480 481 Returns: 482 成功した場合True 483 """ 484 if not self._enabled: 485 return False 486 487 try: 488 blob_client = self._get_blob_client(object_name) 489 490 await asyncio.get_event_loop().run_in_executor( 491 None, 492 lambda: blob_client.delete_blob(delete_snapshots="include") 493 ) 494 495 logger.info("Successfully deleted blob") 496 return True 497 498 except Exception as e: 499 logger.error(f"Failed to delete blob: {e}") 500 return False
オブジェクトを削除
Args: object_name: 削除するBlob名
Returns: 成功した場合True
502 async def object_exists(self, object_name: str) -> bool: 503 """ 504 オブジェクトの存在確認 505 506 Args: 507 object_name: 確認するBlob名 508 509 Returns: 510 存在する場合True 511 """ 512 if not self._enabled: 513 return False 514 515 try: 516 blob_client = self._get_blob_client(object_name) 517 return await asyncio.get_event_loop().run_in_executor( 518 None, blob_client.exists 519 ) 520 except Exception as e: 521 logger.error(f"Failed to check blob existence: {e}") 522 return False
オブジェクトの存在確認
Args: object_name: 確認するBlob名
Returns: 存在する場合True
524 async def get_object_url( 525 self, 526 object_name: str, 527 expires_in: Optional[int] = None, 528 ) -> str: 529 """ 530 オブジェクトのURLを取得 531 532 Args: 533 object_name: Blob名 534 expires_in: SAS有効期限(秒)- 未実装、将来拡張用 535 536 Returns: 537 オブジェクトURL 538 """ 539 if not self._enabled: 540 return "" 541 542 blob_client = self._get_blob_client(object_name) 543 return self._generate_public_url(blob_client.url)
オブジェクトのURLを取得
Args: object_name: Blob名 expires_in: SAS有効期限(秒)- 未実装、将来拡張用
Returns: オブジェクトURL
563 async def download_objects_by_prefix( 564 self, 565 prefix: str, 566 download_dir: str, 567 ) -> Dict[str, Any]: 568 """ 569 プレフィックスに一致するオブジェクトをすべてダウンロード 570 571 Args: 572 prefix: プレフィックス 573 download_dir: ダウンロード先ディレクトリ 574 575 Returns: 576 ダウンロード結果サマリー 577 578 Security: 579 パストラバーサル攻撃を防ぐため、オブジェクト名をサニタイズし、 580 ダウンロード先がdownload_dir配下であることを検証します。 581 """ 582 list_result = await self.list_objects(prefix) 583 if not list_result.success: 584 return { 585 "success": False, 586 "error": list_result.error, 587 "downloaded": [], 588 "failed": [], 589 } 590 591 download_dir_obj = Path(download_dir).resolve() 592 downloaded = [] 593 failed = [] 594 skipped = [] 595 596 for obj in list_result.objects: 597 # プレフィックス除去 598 relative_path = obj.name 599 if obj.name.startswith(prefix): 600 relative_path = obj.name[len(prefix):] 601 602 # パスサニタイズ(パストラバーサル防止) 603 relative_path = self._sanitize_path(relative_path) 604 605 if not relative_path: 606 skipped.append({ 607 "object_name": obj.name, 608 "reason": "Invalid path after sanitization", 609 }) 610 continue 611 612 download_path = (download_dir_obj / relative_path).resolve() 613 614 # パストラバーサル検証: download_dir配下であることを確認 615 try: 616 download_path.relative_to(download_dir_obj) 617 except ValueError: 618 logger.warning( 619 f"Path traversal attempt detected, skipping object" 620 ) 621 skipped.append({ 622 "object_name": obj.name, 623 "reason": "Path traversal detected", 624 }) 625 continue 626 627 result = await self.download_file(obj.name, str(download_path)) 628 if result.success: 629 downloaded.append({ 630 "object_name": obj.name, 631 "local_path": result.local_path, 632 "file_size": result.file_size, 633 }) 634 else: 635 failed.append({ 636 "object_name": obj.name, 637 "error": result.error, 638 }) 639 640 return { 641 "success": True, 642 "downloaded": downloaded, 643 "failed": failed, 644 "skipped": skipped, 645 "success_count": len(downloaded), 646 "failed_count": len(failed), 647 "skipped_count": len(skipped), 648 }
プレフィックスに一致するオブジェクトをすべてダウンロード
Args: prefix: プレフィックス download_dir: ダウンロード先ディレクトリ
Returns: ダウンロード結果サマリー
Security: パストラバーサル攻撃を防ぐため、オブジェクト名をサニタイズし、 ダウンロード先がdownload_dir配下であることを検証します。
650 async def close(self) -> None: 651 """クライアントリソースをクローズ 652 653 Note: 654 例外が発生してもリソースを確実にクリーンアップします。 655 """ 656 try: 657 if self._blob_service_client: 658 # BlobServiceClientはclose()メソッドを持たないが、 659 # 将来のバージョンで追加される可能性があるため確認 660 if hasattr(self._blob_service_client, 'close'): 661 self._blob_service_client.close() 662 except Exception as e: 663 logger.warning(f"Error closing Azure Blob client: {e}") 664 finally: 665 self._blob_service_client = None 666 self._enabled = False 667 logger.debug("AzureBlobStorageClient closed")
クライアントリソースをクローズ
Note: 例外が発生してもリソースを確実にクリーンアップします。
48class S3StorageClient(StorageClientBase): 49 """ 50 AWS S3ストレージクライアント 51 52 Features: 53 - ファイルアップロード・ダウンロード 54 - オブジェクト一覧・削除 55 - 署名付きURL生成(presigned URL) 56 - カスタムドメイン対応 57 - 日本語ファイル名対応(URLエンコード) 58 - メタデータサニタイズ(Base64エンコード) 59 60 Example: 61 config = S3Config( 62 bucket_name="my-bucket", 63 aws_access_key_id="AKIA...", 64 aws_secret_access_key="...", 65 region_name="ap-northeast-1", 66 ) 67 client = S3StorageClient(config) 68 result = await client.upload_file("/path/to/file.txt") 69 """ 70 71 def __init__(self, config: S3Config): 72 """ 73 Initialize AWS S3 Storage Client 74 75 Args: 76 config: S3Config instance 77 78 Raises: 79 StorageConfigError: 設定が不正な場合 80 StorageConnectionError: 接続に失敗した場合 81 """ 82 super().__init__(config) 83 84 if not S3_AVAILABLE: 85 raise StorageConfigError( 86 "boto3 package is not installed. " 87 "Install with: pip install boto3" 88 ) 89 90 if not config.bucket_name: 91 raise StorageConfigError("S3 bucket_name is required") 92 93 if not config.aws_access_key_id or not config.aws_secret_access_key: 94 raise StorageConfigError( 95 "S3 requires aws_access_key_id and aws_secret_access_key" 96 ) 97 98 self._s3_config: S3Config = config 99 self._s3_client = None 100 101 try: 102 session = boto3.Session( 103 aws_access_key_id=self._get_api_key(config.aws_access_key_id), 104 aws_secret_access_key=self._get_api_key(config.aws_secret_access_key), 105 region_name=config.region_name, 106 ) 107 108 client_kwargs = {} 109 if config.endpoint_url: 110 client_kwargs['endpoint_url'] = config.endpoint_url 111 112 self._s3_client = session.client('s3', **client_kwargs) 113 self._enabled = config.enabled 114 115 logger.info( 116 f"S3StorageClient initialized - bucket={config.bucket_name}, " 117 f"region={config.region_name}" 118 ) 119 120 except Exception as e: 121 logger.error(f"Failed to initialize S3 client: {e}") 122 raise StorageConnectionError( 123 f"Failed to connect to S3: {e}" 124 ) from e 125 126 @property 127 def provider(self) -> StorageProvider: 128 return StorageProvider.AWS_S3 129 130 @property 131 def bucket_name(self) -> str: 132 return self._s3_config.bucket_name 133 134 # _encode_object_name: StorageClientBase に移動済み 135 # _sanitize_path: StorageClientBase に移動済み 136 137 def _generate_public_url(self, object_name: str) -> str: 138 """公開URLを生成(カスタムドメイン対応)""" 139 custom_domain = self._s3_config.custom_domain 140 if custom_domain: 141 return f"https://{custom_domain}/{object_name}" 142 143 if self._s3_config.endpoint_url: 144 return f"{self._s3_config.endpoint_url}/{self.bucket_name}/{object_name}" 145 146 return f"https://{self.bucket_name}.s3.{self._s3_config.region_name}.amazonaws.com/{object_name}" 147 148 async def ensure_bucket_exists(self) -> bool: 149 """バケットの存在確認・作成""" 150 if not self._enabled: 151 return False 152 153 try: 154 await asyncio.get_running_loop().run_in_executor( 155 None, 156 lambda: self._s3_client.head_bucket(Bucket=self.bucket_name) 157 ) 158 return True 159 except ClientError as e: 160 error_code = e.response.get('Error', {}).get('Code', '') 161 if error_code in ('404', 'NoSuchBucket'): 162 if self._s3_config.auto_create_bucket: 163 try: 164 create_params = {'Bucket': self.bucket_name} 165 if self._s3_config.region_name != 'us-east-1': 166 create_params['CreateBucketConfiguration'] = { 167 'LocationConstraint': self._s3_config.region_name 168 } 169 await asyncio.get_running_loop().run_in_executor( 170 None, 171 lambda: self._s3_client.create_bucket(**create_params) 172 ) 173 logger.info(f"Created S3 bucket: {self.bucket_name}") 174 return True 175 except Exception as create_error: 176 logger.error(f"Failed to create bucket: {create_error}") 177 return False 178 else: 179 logger.error(f"Bucket {self.bucket_name} does not exist") 180 return False 181 logger.error(f"Failed to check bucket: {e}") 182 return False 183 except Exception as e: 184 logger.error(f"Failed to ensure bucket exists: {e}") 185 return False 186 187 async def upload_file( 188 self, 189 file_path: str, 190 object_name: Optional[str] = None, 191 prefix: Optional[str] = None, 192 metadata: Optional[Dict[str, Any]] = None, 193 ) -> UploadResult: 194 """ 195 ファイルをS3にアップロード 196 197 Args: 198 file_path: アップロードするファイルのローカルパス 199 object_name: オブジェクト名(省略時はファイル名が使用される) 200 prefix: オブジェクト名プレフィックス(フォルダ構造を作成) 201 metadata: 追加メタデータ(日本語はBase64エンコードされる) 202 203 Returns: 204 UploadResult: アップロード結果 205 """ 206 if not self._enabled: 207 return UploadResult( 208 success=False, 209 error="S3 Storage is not enabled", 210 error_code="STORAGE_NOT_ENABLED", 211 ) 212 213 try: 214 self._validate_file_size(file_path) 215 216 file_path_obj = Path(file_path) 217 218 if not object_name: 219 object_name = file_path_obj.name 220 221 safe_object_name = self._encode_object_name(object_name) 222 223 if prefix: 224 full_object_name = f"{prefix}/{safe_object_name}" 225 else: 226 full_object_name = safe_object_name 227 228 if not await self.ensure_bucket_exists(): 229 return UploadResult( 230 success=False, 231 error="Bucket is not available", 232 error_code="BUCKET_NOT_AVAILABLE", 233 ) 234 235 content_type = self._get_content_type(file_path) 236 safe_metadata = self._sanitize_metadata(metadata) 237 238 extra_args = { 239 'ContentType': content_type, 240 } 241 if safe_metadata: 242 extra_args['Metadata'] = safe_metadata 243 244 file_size = file_path_obj.stat().st_size 245 246 await asyncio.get_running_loop().run_in_executor( 247 None, 248 lambda: self._s3_client.upload_file( 249 str(file_path_obj), 250 self.bucket_name, 251 full_object_name, 252 ExtraArgs=extra_args, 253 ) 254 ) 255 256 object_url = self._generate_public_url(full_object_name) 257 258 logger.info("Successfully uploaded file to S3") 259 260 return UploadResult( 261 success=True, 262 object_name=full_object_name, 263 object_url=object_url, 264 file_size=file_size, 265 content_type=content_type, 266 metadata={ 267 "uploaded_at": time.time(), 268 "prefix": prefix, 269 **(safe_metadata or {}), 270 }, 271 ) 272 273 except StorageOperationError: 274 raise 275 except Exception as e: 276 logger.error(f"Failed to upload file: {e}") 277 return UploadResult( 278 success=False, 279 error=f"Upload failed: {e}", 280 error_code="UPLOAD_FAILED", 281 ) 282 283 async def download_file( 284 self, 285 object_name: str, 286 download_path: str, 287 ) -> DownloadResult: 288 """ 289 ファイルをS3からダウンロード(ストリーミング) 290 291 Args: 292 object_name: ダウンロードするオブジェクト名 293 download_path: ダウンロード先のローカルパス 294 295 Returns: 296 DownloadResult: ダウンロード結果 297 """ 298 if not self._enabled: 299 return DownloadResult( 300 success=False, 301 error="S3 Storage is not enabled", 302 error_code="STORAGE_NOT_ENABLED", 303 ) 304 305 try: 306 # 存在確認 307 exists = await self.object_exists(object_name) 308 if not exists: 309 return DownloadResult( 310 success=False, 311 object_name=object_name, 312 error=f"Object not found: {object_name}", 313 error_code="OBJECT_NOT_FOUND", 314 ) 315 316 download_path_obj = Path(download_path) 317 download_path_obj.parent.mkdir(parents=True, exist_ok=True) 318 319 await asyncio.get_running_loop().run_in_executor( 320 None, 321 lambda: self._s3_client.download_file( 322 self.bucket_name, 323 object_name, 324 str(download_path_obj), 325 ) 326 ) 327 328 file_size = download_path_obj.stat().st_size 329 logger.info("Successfully downloaded object from S3") 330 331 return DownloadResult( 332 success=True, 333 object_name=object_name, 334 local_path=str(download_path_obj), 335 file_size=file_size, 336 ) 337 338 except Exception as e: 339 logger.error(f"Failed to download object: {e}") 340 return DownloadResult( 341 success=False, 342 object_name=object_name, 343 error=f"Download failed: {e}", 344 error_code="DOWNLOAD_FAILED", 345 ) 346 347 async def list_objects( 348 self, 349 prefix: str = "", 350 max_results: Optional[int] = None, 351 ) -> ListResult: 352 """ 353 オブジェクト一覧を取得 354 355 Args: 356 prefix: プレフィックスフィルタ 357 max_results: 最大取得数 358 359 Returns: 360 ListResult: 一覧結果 361 """ 362 if not self._enabled: 363 return ListResult( 364 success=False, 365 error="S3 Storage is not enabled", 366 error_code="STORAGE_NOT_ENABLED", 367 ) 368 369 try: 370 paginator = self._s3_client.get_paginator('list_objects_v2') 371 page_config = { 372 'Bucket': self.bucket_name, 373 'Prefix': prefix, 374 } 375 if max_results: 376 page_config['PaginationConfig'] = {'MaxItems': max_results} 377 378 def _list(): 379 objects = [] 380 for page in paginator.paginate(**page_config): 381 for obj in page.get('Contents', []): 382 objects.append(ObjectInfo( 383 name=obj['Key'], 384 size=obj.get('Size', 0), 385 last_modified=obj['LastModified'].isoformat() if obj.get('LastModified') else None, 386 content_type=None, 387 )) 388 return objects 389 390 objects = await asyncio.get_running_loop().run_in_executor(None, _list) 391 392 logger.info(f"Listed {len(objects)} objects from S3") 393 394 return ListResult( 395 success=True, 396 objects=objects, 397 count=len(objects), 398 prefix=prefix, 399 ) 400 401 except Exception as e: 402 logger.error(f"Failed to list objects: {e}") 403 return ListResult( 404 success=False, 405 error=f"List failed: {e}", 406 error_code="LIST_FAILED", 407 ) 408 409 async def delete_object(self, object_name: str) -> bool: 410 """オブジェクトを削除""" 411 if not self._enabled: 412 return False 413 414 try: 415 await asyncio.get_running_loop().run_in_executor( 416 None, 417 lambda: self._s3_client.delete_object( 418 Bucket=self.bucket_name, 419 Key=object_name, 420 ) 421 ) 422 logger.info("Successfully deleted object from S3") 423 return True 424 425 except Exception as e: 426 logger.error(f"Failed to delete object: {e}") 427 return False 428 429 async def object_exists(self, object_name: str) -> bool: 430 """オブジェクトの存在確認""" 431 if not self._enabled: 432 return False 433 434 try: 435 await asyncio.get_running_loop().run_in_executor( 436 None, 437 lambda: self._s3_client.head_object( 438 Bucket=self.bucket_name, 439 Key=object_name, 440 ) 441 ) 442 return True 443 except ClientError as e: 444 if e.response.get('Error', {}).get('Code', '') == '404': 445 return False 446 logger.error(f"Failed to check object existence: {e}") 447 return False 448 except Exception as e: 449 logger.error(f"Failed to check object existence: {e}") 450 return False 451 452 async def get_object_url( 453 self, 454 object_name: str, 455 expires_in: Optional[int] = None, 456 ) -> str: 457 """ 458 署名付きURLを生成 459 460 Args: 461 object_name: オブジェクト名 462 expires_in: 有効期限(秒)。省略時は公開URLを返す 463 464 Returns: 465 署名付きURL or 公開URL 466 """ 467 if not self._enabled: 468 return "" 469 470 if expires_in: 471 try: 472 url = await asyncio.get_running_loop().run_in_executor( 473 None, 474 lambda: self._s3_client.generate_presigned_url( 475 'get_object', 476 Params={ 477 'Bucket': self.bucket_name, 478 'Key': object_name, 479 }, 480 ExpiresIn=expires_in, 481 ) 482 ) 483 return url 484 except Exception as e: 485 logger.error(f"Failed to generate presigned URL: {e}") 486 return "" 487 488 return self._generate_public_url(object_name) 489 490 # download_objects_by_prefix: StorageClientBase に移動済み 491 492 async def close(self) -> None: 493 """クライアントリソースをクローズ""" 494 try: 495 if self._s3_client: 496 self._s3_client.close() 497 except Exception as e: 498 logger.warning(f"Error during S3 client close: {e}") 499 finally: 500 self._s3_client = None 501 self._enabled = False 502 logger.debug("S3StorageClient closed")
AWS S3ストレージクライアント
Features:
- ファイルアップロード・ダウンロード
- オブジェクト一覧・削除
- 署名付きURL生成(presigned URL)
- カスタムドメイン対応
- 日本語ファイル名対応(URLエンコード)
- メタデータサニタイズ(Base64エンコード)
Example: config = S3Config( bucket_name="my-bucket", aws_access_key_id="AKIA...", aws_secret_access_key="...", region_name="ap-northeast-1", ) client = S3StorageClient(config) result = await client.upload_file("/path/to/file.txt")
71 def __init__(self, config: S3Config): 72 """ 73 Initialize AWS S3 Storage Client 74 75 Args: 76 config: S3Config instance 77 78 Raises: 79 StorageConfigError: 設定が不正な場合 80 StorageConnectionError: 接続に失敗した場合 81 """ 82 super().__init__(config) 83 84 if not S3_AVAILABLE: 85 raise StorageConfigError( 86 "boto3 package is not installed. " 87 "Install with: pip install boto3" 88 ) 89 90 if not config.bucket_name: 91 raise StorageConfigError("S3 bucket_name is required") 92 93 if not config.aws_access_key_id or not config.aws_secret_access_key: 94 raise StorageConfigError( 95 "S3 requires aws_access_key_id and aws_secret_access_key" 96 ) 97 98 self._s3_config: S3Config = config 99 self._s3_client = None 100 101 try: 102 session = boto3.Session( 103 aws_access_key_id=self._get_api_key(config.aws_access_key_id), 104 aws_secret_access_key=self._get_api_key(config.aws_secret_access_key), 105 region_name=config.region_name, 106 ) 107 108 client_kwargs = {} 109 if config.endpoint_url: 110 client_kwargs['endpoint_url'] = config.endpoint_url 111 112 self._s3_client = session.client('s3', **client_kwargs) 113 self._enabled = config.enabled 114 115 logger.info( 116 f"S3StorageClient initialized - bucket={config.bucket_name}, " 117 f"region={config.region_name}" 118 ) 119 120 except Exception as e: 121 logger.error(f"Failed to initialize S3 client: {e}") 122 raise StorageConnectionError( 123 f"Failed to connect to S3: {e}" 124 ) from e
Initialize AWS S3 Storage Client
Args: config: S3Config instance
Raises: StorageConfigError: 設定が不正な場合 StorageConnectionError: 接続に失敗した場合
148 async def ensure_bucket_exists(self) -> bool: 149 """バケットの存在確認・作成""" 150 if not self._enabled: 151 return False 152 153 try: 154 await asyncio.get_running_loop().run_in_executor( 155 None, 156 lambda: self._s3_client.head_bucket(Bucket=self.bucket_name) 157 ) 158 return True 159 except ClientError as e: 160 error_code = e.response.get('Error', {}).get('Code', '') 161 if error_code in ('404', 'NoSuchBucket'): 162 if self._s3_config.auto_create_bucket: 163 try: 164 create_params = {'Bucket': self.bucket_name} 165 if self._s3_config.region_name != 'us-east-1': 166 create_params['CreateBucketConfiguration'] = { 167 'LocationConstraint': self._s3_config.region_name 168 } 169 await asyncio.get_running_loop().run_in_executor( 170 None, 171 lambda: self._s3_client.create_bucket(**create_params) 172 ) 173 logger.info(f"Created S3 bucket: {self.bucket_name}") 174 return True 175 except Exception as create_error: 176 logger.error(f"Failed to create bucket: {create_error}") 177 return False 178 else: 179 logger.error(f"Bucket {self.bucket_name} does not exist") 180 return False 181 logger.error(f"Failed to check bucket: {e}") 182 return False 183 except Exception as e: 184 logger.error(f"Failed to ensure bucket exists: {e}") 185 return False
バケットの存在確認・作成
187 async def upload_file( 188 self, 189 file_path: str, 190 object_name: Optional[str] = None, 191 prefix: Optional[str] = None, 192 metadata: Optional[Dict[str, Any]] = None, 193 ) -> UploadResult: 194 """ 195 ファイルをS3にアップロード 196 197 Args: 198 file_path: アップロードするファイルのローカルパス 199 object_name: オブジェクト名(省略時はファイル名が使用される) 200 prefix: オブジェクト名プレフィックス(フォルダ構造を作成) 201 metadata: 追加メタデータ(日本語はBase64エンコードされる) 202 203 Returns: 204 UploadResult: アップロード結果 205 """ 206 if not self._enabled: 207 return UploadResult( 208 success=False, 209 error="S3 Storage is not enabled", 210 error_code="STORAGE_NOT_ENABLED", 211 ) 212 213 try: 214 self._validate_file_size(file_path) 215 216 file_path_obj = Path(file_path) 217 218 if not object_name: 219 object_name = file_path_obj.name 220 221 safe_object_name = self._encode_object_name(object_name) 222 223 if prefix: 224 full_object_name = f"{prefix}/{safe_object_name}" 225 else: 226 full_object_name = safe_object_name 227 228 if not await self.ensure_bucket_exists(): 229 return UploadResult( 230 success=False, 231 error="Bucket is not available", 232 error_code="BUCKET_NOT_AVAILABLE", 233 ) 234 235 content_type = self._get_content_type(file_path) 236 safe_metadata = self._sanitize_metadata(metadata) 237 238 extra_args = { 239 'ContentType': content_type, 240 } 241 if safe_metadata: 242 extra_args['Metadata'] = safe_metadata 243 244 file_size = file_path_obj.stat().st_size 245 246 await asyncio.get_running_loop().run_in_executor( 247 None, 248 lambda: self._s3_client.upload_file( 249 str(file_path_obj), 250 self.bucket_name, 251 full_object_name, 252 ExtraArgs=extra_args, 253 ) 254 ) 255 256 object_url = self._generate_public_url(full_object_name) 257 258 logger.info("Successfully uploaded file to S3") 259 260 return UploadResult( 261 success=True, 262 object_name=full_object_name, 263 object_url=object_url, 264 file_size=file_size, 265 content_type=content_type, 266 metadata={ 267 "uploaded_at": time.time(), 268 "prefix": prefix, 269 **(safe_metadata or {}), 270 }, 271 ) 272 273 except StorageOperationError: 274 raise 275 except Exception as e: 276 logger.error(f"Failed to upload file: {e}") 277 return UploadResult( 278 success=False, 279 error=f"Upload failed: {e}", 280 error_code="UPLOAD_FAILED", 281 )
ファイルをS3にアップロード
Args: file_path: アップロードするファイルのローカルパス object_name: オブジェクト名(省略時はファイル名が使用される) prefix: オブジェクト名プレフィックス(フォルダ構造を作成) metadata: 追加メタデータ(日本語はBase64エンコードされる)
Returns: UploadResult: アップロード結果
283 async def download_file( 284 self, 285 object_name: str, 286 download_path: str, 287 ) -> DownloadResult: 288 """ 289 ファイルをS3からダウンロード(ストリーミング) 290 291 Args: 292 object_name: ダウンロードするオブジェクト名 293 download_path: ダウンロード先のローカルパス 294 295 Returns: 296 DownloadResult: ダウンロード結果 297 """ 298 if not self._enabled: 299 return DownloadResult( 300 success=False, 301 error="S3 Storage is not enabled", 302 error_code="STORAGE_NOT_ENABLED", 303 ) 304 305 try: 306 # 存在確認 307 exists = await self.object_exists(object_name) 308 if not exists: 309 return DownloadResult( 310 success=False, 311 object_name=object_name, 312 error=f"Object not found: {object_name}", 313 error_code="OBJECT_NOT_FOUND", 314 ) 315 316 download_path_obj = Path(download_path) 317 download_path_obj.parent.mkdir(parents=True, exist_ok=True) 318 319 await asyncio.get_running_loop().run_in_executor( 320 None, 321 lambda: self._s3_client.download_file( 322 self.bucket_name, 323 object_name, 324 str(download_path_obj), 325 ) 326 ) 327 328 file_size = download_path_obj.stat().st_size 329 logger.info("Successfully downloaded object from S3") 330 331 return DownloadResult( 332 success=True, 333 object_name=object_name, 334 local_path=str(download_path_obj), 335 file_size=file_size, 336 ) 337 338 except Exception as e: 339 logger.error(f"Failed to download object: {e}") 340 return DownloadResult( 341 success=False, 342 object_name=object_name, 343 error=f"Download failed: {e}", 344 error_code="DOWNLOAD_FAILED", 345 )
ファイルをS3からダウンロード(ストリーミング)
Args: object_name: ダウンロードするオブジェクト名 download_path: ダウンロード先のローカルパス
Returns: DownloadResult: ダウンロード結果
347 async def list_objects( 348 self, 349 prefix: str = "", 350 max_results: Optional[int] = None, 351 ) -> ListResult: 352 """ 353 オブジェクト一覧を取得 354 355 Args: 356 prefix: プレフィックスフィルタ 357 max_results: 最大取得数 358 359 Returns: 360 ListResult: 一覧結果 361 """ 362 if not self._enabled: 363 return ListResult( 364 success=False, 365 error="S3 Storage is not enabled", 366 error_code="STORAGE_NOT_ENABLED", 367 ) 368 369 try: 370 paginator = self._s3_client.get_paginator('list_objects_v2') 371 page_config = { 372 'Bucket': self.bucket_name, 373 'Prefix': prefix, 374 } 375 if max_results: 376 page_config['PaginationConfig'] = {'MaxItems': max_results} 377 378 def _list(): 379 objects = [] 380 for page in paginator.paginate(**page_config): 381 for obj in page.get('Contents', []): 382 objects.append(ObjectInfo( 383 name=obj['Key'], 384 size=obj.get('Size', 0), 385 last_modified=obj['LastModified'].isoformat() if obj.get('LastModified') else None, 386 content_type=None, 387 )) 388 return objects 389 390 objects = await asyncio.get_running_loop().run_in_executor(None, _list) 391 392 logger.info(f"Listed {len(objects)} objects from S3") 393 394 return ListResult( 395 success=True, 396 objects=objects, 397 count=len(objects), 398 prefix=prefix, 399 ) 400 401 except Exception as e: 402 logger.error(f"Failed to list objects: {e}") 403 return ListResult( 404 success=False, 405 error=f"List failed: {e}", 406 error_code="LIST_FAILED", 407 )
オブジェクト一覧を取得
Args: prefix: プレフィックスフィルタ max_results: 最大取得数
Returns: ListResult: 一覧結果
409 async def delete_object(self, object_name: str) -> bool: 410 """オブジェクトを削除""" 411 if not self._enabled: 412 return False 413 414 try: 415 await asyncio.get_running_loop().run_in_executor( 416 None, 417 lambda: self._s3_client.delete_object( 418 Bucket=self.bucket_name, 419 Key=object_name, 420 ) 421 ) 422 logger.info("Successfully deleted object from S3") 423 return True 424 425 except Exception as e: 426 logger.error(f"Failed to delete object: {e}") 427 return False
オブジェクトを削除
429 async def object_exists(self, object_name: str) -> bool: 430 """オブジェクトの存在確認""" 431 if not self._enabled: 432 return False 433 434 try: 435 await asyncio.get_running_loop().run_in_executor( 436 None, 437 lambda: self._s3_client.head_object( 438 Bucket=self.bucket_name, 439 Key=object_name, 440 ) 441 ) 442 return True 443 except ClientError as e: 444 if e.response.get('Error', {}).get('Code', '') == '404': 445 return False 446 logger.error(f"Failed to check object existence: {e}") 447 return False 448 except Exception as e: 449 logger.error(f"Failed to check object existence: {e}") 450 return False
オブジェクトの存在確認
452 async def get_object_url( 453 self, 454 object_name: str, 455 expires_in: Optional[int] = None, 456 ) -> str: 457 """ 458 署名付きURLを生成 459 460 Args: 461 object_name: オブジェクト名 462 expires_in: 有効期限(秒)。省略時は公開URLを返す 463 464 Returns: 465 署名付きURL or 公開URL 466 """ 467 if not self._enabled: 468 return "" 469 470 if expires_in: 471 try: 472 url = await asyncio.get_running_loop().run_in_executor( 473 None, 474 lambda: self._s3_client.generate_presigned_url( 475 'get_object', 476 Params={ 477 'Bucket': self.bucket_name, 478 'Key': object_name, 479 }, 480 ExpiresIn=expires_in, 481 ) 482 ) 483 return url 484 except Exception as e: 485 logger.error(f"Failed to generate presigned URL: {e}") 486 return "" 487 488 return self._generate_public_url(object_name)
署名付きURLを生成
Args: object_name: オブジェクト名 expires_in: 有効期限(秒)。省略時は公開URLを返す
Returns: 署名付きURL or 公開URL
492 async def close(self) -> None: 493 """クライアントリソースをクローズ""" 494 try: 495 if self._s3_client: 496 self._s3_client.close() 497 except Exception as e: 498 logger.warning(f"Error during S3 client close: {e}") 499 finally: 500 self._s3_client = None 501 self._enabled = False 502 logger.debug("S3StorageClient closed")
クライアントリソースをクローズ
48class GCSStorageClient(StorageClientBase): 49 """ 50 Google Cloud Storageクライアント 51 52 Features: 53 - ファイルアップロード・ダウンロード 54 - オブジェクト一覧・削除 55 - 署名付きURL生成 56 - カスタムドメイン対応 57 - 日本語ファイル名対応(URLエンコード) 58 - メタデータサニタイズ(Base64エンコード) 59 60 Example: 61 config = GCSConfig( 62 bucket_name="my-bucket", 63 project_id="my-project", 64 credentials_path="/path/to/service-account.json", 65 ) 66 client = GCSStorageClient(config) 67 result = await client.upload_file("/path/to/file.txt") 68 """ 69 70 def __init__(self, config: GCSConfig): 71 """ 72 Initialize Google Cloud Storage Client 73 74 Args: 75 config: GCSConfig instance 76 77 Raises: 78 StorageConfigError: 設定が不正な場合 79 StorageConnectionError: 接続に失敗した場合 80 """ 81 super().__init__(config) 82 83 if not GCS_AVAILABLE: 84 raise StorageConfigError( 85 "google-cloud-storage package is not installed. " 86 "Install with: pip install google-cloud-storage" 87 ) 88 89 if not config.bucket_name: 90 raise StorageConfigError("GCS bucket_name is required") 91 92 if not config.project_id: 93 raise StorageConfigError("GCS project_id is required") 94 95 self._gcs_config: GCSConfig = config 96 self._gcs_client = None 97 self._bucket = None 98 99 try: 100 credentials_json = config.credentials_json 101 # Base64エンコード済みの場合はデコード (Helm TOML-safe format) 102 if not credentials_json and config.credentials_base64: 103 import base64 104 credentials_json = base64.b64decode(config.credentials_base64).decode('utf-8') 105 106 if config.credentials_path: 107 credentials = service_account.Credentials.from_service_account_file( 108 config.credentials_path 109 ) 110 self._gcs_client = gcs_storage.Client( 111 project=config.project_id, 112 credentials=credentials, 113 ) 114 elif credentials_json: 115 import json 116 credentials_info = json.loads(credentials_json) 117 credentials = service_account.Credentials.from_service_account_info( 118 credentials_info 119 ) 120 self._gcs_client = gcs_storage.Client( 121 project=config.project_id, 122 credentials=credentials, 123 ) 124 else: 125 # デフォルト認証(環境変数 GOOGLE_APPLICATION_CREDENTIALS) 126 self._gcs_client = gcs_storage.Client(project=config.project_id) 127 128 self._enabled = config.enabled 129 130 logger.info( 131 f"GCSStorageClient initialized - bucket={config.bucket_name}, " 132 f"project={config.project_id}" 133 ) 134 135 except Exception as e: 136 logger.error(f"Failed to initialize GCS client: {e}") 137 raise StorageConnectionError( 138 f"Failed to connect to GCS: {e}" 139 ) from e 140 141 @property 142 def provider(self) -> StorageProvider: 143 return StorageProvider.GCS 144 145 @property 146 def bucket_name(self) -> str: 147 return self._gcs_config.bucket_name 148 149 # _encode_object_name: StorageClientBase に移動済み 150 # _sanitize_path: StorageClientBase に移動済み 151 152 def _generate_public_url(self, object_name: str) -> str: 153 """公開URLを生成(カスタムドメイン対応)""" 154 custom_domain = self._gcs_config.custom_domain 155 if custom_domain: 156 return f"https://{custom_domain}/{object_name}" 157 158 return f"https://storage.googleapis.com/{self.bucket_name}/{object_name}" 159 160 async def ensure_bucket_exists(self) -> bool: 161 """バケットの存在確認・作成""" 162 if not self._enabled: 163 return False 164 165 try: 166 self._bucket = self._gcs_client.bucket(self.bucket_name) 167 168 exists = await asyncio.get_running_loop().run_in_executor( 169 None, self._bucket.exists 170 ) 171 if exists: 172 return True 173 174 if self._gcs_config.auto_create_bucket: 175 self._bucket = await asyncio.get_running_loop().run_in_executor( 176 None, 177 lambda: self._gcs_client.create_bucket(self.bucket_name) 178 ) 179 logger.info(f"Created GCS bucket: {self.bucket_name}") 180 return True 181 else: 182 logger.error(f"Bucket {self.bucket_name} does not exist") 183 return False 184 185 except Exception as e: 186 logger.error(f"Failed to ensure bucket exists: {e}") 187 return False 188 189 def _get_bucket(self): 190 """バケットオブジェクトを取得(キャッシュ付き)""" 191 if self._bucket is None: 192 self._bucket = self._gcs_client.bucket(self.bucket_name) 193 return self._bucket 194 195 async def upload_file( 196 self, 197 file_path: str, 198 object_name: Optional[str] = None, 199 prefix: Optional[str] = None, 200 metadata: Optional[Dict[str, Any]] = None, 201 ) -> UploadResult: 202 """ 203 ファイルをGCSにアップロード 204 205 Args: 206 file_path: アップロードするファイルのローカルパス 207 object_name: オブジェクト名(省略時はファイル名が使用される) 208 prefix: オブジェクト名プレフィックス(フォルダ構造を作成) 209 metadata: 追加メタデータ(日本語はBase64エンコードされる) 210 211 Returns: 212 UploadResult: アップロード結果 213 """ 214 if not self._enabled: 215 return UploadResult( 216 success=False, 217 error="GCS Storage is not enabled", 218 error_code="STORAGE_NOT_ENABLED", 219 ) 220 221 try: 222 self._validate_file_size(file_path) 223 224 file_path_obj = Path(file_path) 225 226 if not object_name: 227 object_name = file_path_obj.name 228 229 safe_object_name = self._encode_object_name(object_name) 230 231 if prefix: 232 full_object_name = f"{prefix}/{safe_object_name}" 233 else: 234 full_object_name = safe_object_name 235 236 if not await self.ensure_bucket_exists(): 237 return UploadResult( 238 success=False, 239 error="Bucket is not available", 240 error_code="BUCKET_NOT_AVAILABLE", 241 ) 242 243 content_type = self._get_content_type(file_path) 244 safe_metadata = self._sanitize_metadata(metadata) 245 246 bucket = self._get_bucket() 247 blob = bucket.blob(full_object_name) 248 249 if safe_metadata: 250 blob.metadata = safe_metadata 251 252 file_size = file_path_obj.stat().st_size 253 254 await asyncio.get_running_loop().run_in_executor( 255 None, 256 lambda: blob.upload_from_filename( 257 str(file_path_obj), 258 content_type=content_type, 259 ) 260 ) 261 262 object_url = self._generate_public_url(full_object_name) 263 264 logger.info("Successfully uploaded file to GCS") 265 266 return UploadResult( 267 success=True, 268 object_name=full_object_name, 269 object_url=object_url, 270 file_size=file_size, 271 content_type=content_type, 272 metadata={ 273 "uploaded_at": time.time(), 274 "prefix": prefix, 275 **(safe_metadata or {}), 276 }, 277 ) 278 279 except StorageOperationError: 280 raise 281 except Exception as e: 282 logger.error(f"Failed to upload file: {e}") 283 return UploadResult( 284 success=False, 285 error=f"Upload failed: {e}", 286 error_code="UPLOAD_FAILED", 287 ) 288 289 async def download_file( 290 self, 291 object_name: str, 292 download_path: str, 293 ) -> DownloadResult: 294 """ 295 ファイルをGCSからダウンロード 296 297 Args: 298 object_name: ダウンロードするオブジェクト名 299 download_path: ダウンロード先のローカルパス 300 301 Returns: 302 DownloadResult: ダウンロード結果 303 """ 304 if not self._enabled: 305 return DownloadResult( 306 success=False, 307 error="GCS Storage is not enabled", 308 error_code="STORAGE_NOT_ENABLED", 309 ) 310 311 try: 312 bucket = self._get_bucket() 313 blob = bucket.blob(object_name) 314 315 exists = await asyncio.get_running_loop().run_in_executor( 316 None, blob.exists 317 ) 318 if not exists: 319 return DownloadResult( 320 success=False, 321 object_name=object_name, 322 error=f"Object not found: {object_name}", 323 error_code="OBJECT_NOT_FOUND", 324 ) 325 326 download_path_obj = Path(download_path) 327 download_path_obj.parent.mkdir(parents=True, exist_ok=True) 328 329 await asyncio.get_running_loop().run_in_executor( 330 None, 331 lambda: blob.download_to_filename(str(download_path_obj)) 332 ) 333 334 file_size = download_path_obj.stat().st_size 335 logger.info("Successfully downloaded object from GCS") 336 337 return DownloadResult( 338 success=True, 339 object_name=object_name, 340 local_path=str(download_path_obj), 341 file_size=file_size, 342 ) 343 344 except Exception as e: 345 logger.error(f"Failed to download object: {e}") 346 return DownloadResult( 347 success=False, 348 object_name=object_name, 349 error=f"Download failed: {e}", 350 error_code="DOWNLOAD_FAILED", 351 ) 352 353 async def list_objects( 354 self, 355 prefix: str = "", 356 max_results: Optional[int] = None, 357 ) -> ListResult: 358 """ 359 オブジェクト一覧を取得 360 361 Args: 362 prefix: プレフィックスフィルタ 363 max_results: 最大取得数 364 365 Returns: 366 ListResult: 一覧結果 367 """ 368 if not self._enabled: 369 return ListResult( 370 success=False, 371 error="GCS Storage is not enabled", 372 error_code="STORAGE_NOT_ENABLED", 373 ) 374 375 try: 376 bucket = self._get_bucket() 377 378 list_kwargs = {} 379 if prefix: 380 list_kwargs['prefix'] = prefix 381 if max_results: 382 list_kwargs['max_results'] = max_results 383 384 def _list(): 385 blobs = list(bucket.list_blobs(**list_kwargs)) 386 objects = [] 387 for blob in blobs: 388 objects.append(ObjectInfo( 389 name=blob.name, 390 size=blob.size or 0, 391 last_modified=blob.updated.isoformat() if blob.updated else None, 392 content_type=blob.content_type, 393 )) 394 return objects 395 396 objects = await asyncio.get_running_loop().run_in_executor(None, _list) 397 398 logger.info(f"Listed {len(objects)} objects from GCS") 399 400 return ListResult( 401 success=True, 402 objects=objects, 403 count=len(objects), 404 prefix=prefix, 405 ) 406 407 except Exception as e: 408 logger.error(f"Failed to list objects: {e}") 409 return ListResult( 410 success=False, 411 error=f"List failed: {e}", 412 error_code="LIST_FAILED", 413 ) 414 415 async def delete_object(self, object_name: str) -> bool: 416 """オブジェクトを削除""" 417 if not self._enabled: 418 return False 419 420 try: 421 bucket = self._get_bucket() 422 blob = bucket.blob(object_name) 423 424 await asyncio.get_running_loop().run_in_executor( 425 None, blob.delete 426 ) 427 logger.info("Successfully deleted object from GCS") 428 return True 429 430 except Exception as e: 431 logger.error(f"Failed to delete object: {e}") 432 return False 433 434 async def object_exists(self, object_name: str) -> bool: 435 """オブジェクトの存在確認""" 436 if not self._enabled: 437 return False 438 439 try: 440 bucket = self._get_bucket() 441 blob = bucket.blob(object_name) 442 return await asyncio.get_running_loop().run_in_executor( 443 None, blob.exists 444 ) 445 except Exception as e: 446 logger.error(f"Failed to check object existence: {e}") 447 return False 448 449 async def get_object_url( 450 self, 451 object_name: str, 452 expires_in: Optional[int] = None, 453 ) -> str: 454 """ 455 署名付きURLを生成 456 457 Args: 458 object_name: オブジェクト名 459 expires_in: 有効期限(秒)。省略時は公開URLを返す 460 461 Returns: 462 署名付きURL or 公開URL 463 """ 464 if not self._enabled: 465 return "" 466 467 if expires_in: 468 try: 469 import datetime 470 bucket = self._get_bucket() 471 blob = bucket.blob(object_name) 472 url = await asyncio.get_running_loop().run_in_executor( 473 None, 474 lambda: blob.generate_signed_url( 475 expiration=datetime.timedelta(seconds=expires_in), 476 method='GET', 477 ) 478 ) 479 return url 480 except Exception as e: 481 logger.error(f"Failed to generate signed URL: {e}") 482 return "" 483 484 return self._generate_public_url(object_name) 485 486 # download_objects_by_prefix: StorageClientBase に移動済み 487 488 async def close(self) -> None: 489 """クライアントリソースをクローズ""" 490 try: 491 if self._gcs_client: 492 self._gcs_client.close() 493 except Exception as e: 494 logger.warning(f"Error closing GCS client: {e}") 495 finally: 496 self._gcs_client = None 497 self._bucket = None 498 self._enabled = False 499 logger.debug("GCSStorageClient closed")
Google Cloud Storageクライアント
Features:
- ファイルアップロード・ダウンロード
- オブジェクト一覧・削除
- 署名付きURL生成
- カスタムドメイン対応
- 日本語ファイル名対応(URLエンコード)
- メタデータサニタイズ(Base64エンコード)
Example: config = GCSConfig( bucket_name="my-bucket", project_id="my-project", credentials_path="/path/to/service-account.json", ) client = GCSStorageClient(config) result = await client.upload_file("/path/to/file.txt")
70 def __init__(self, config: GCSConfig): 71 """ 72 Initialize Google Cloud Storage Client 73 74 Args: 75 config: GCSConfig instance 76 77 Raises: 78 StorageConfigError: 設定が不正な場合 79 StorageConnectionError: 接続に失敗した場合 80 """ 81 super().__init__(config) 82 83 if not GCS_AVAILABLE: 84 raise StorageConfigError( 85 "google-cloud-storage package is not installed. " 86 "Install with: pip install google-cloud-storage" 87 ) 88 89 if not config.bucket_name: 90 raise StorageConfigError("GCS bucket_name is required") 91 92 if not config.project_id: 93 raise StorageConfigError("GCS project_id is required") 94 95 self._gcs_config: GCSConfig = config 96 self._gcs_client = None 97 self._bucket = None 98 99 try: 100 credentials_json = config.credentials_json 101 # Base64エンコード済みの場合はデコード (Helm TOML-safe format) 102 if not credentials_json and config.credentials_base64: 103 import base64 104 credentials_json = base64.b64decode(config.credentials_base64).decode('utf-8') 105 106 if config.credentials_path: 107 credentials = service_account.Credentials.from_service_account_file( 108 config.credentials_path 109 ) 110 self._gcs_client = gcs_storage.Client( 111 project=config.project_id, 112 credentials=credentials, 113 ) 114 elif credentials_json: 115 import json 116 credentials_info = json.loads(credentials_json) 117 credentials = service_account.Credentials.from_service_account_info( 118 credentials_info 119 ) 120 self._gcs_client = gcs_storage.Client( 121 project=config.project_id, 122 credentials=credentials, 123 ) 124 else: 125 # デフォルト認証(環境変数 GOOGLE_APPLICATION_CREDENTIALS) 126 self._gcs_client = gcs_storage.Client(project=config.project_id) 127 128 self._enabled = config.enabled 129 130 logger.info( 131 f"GCSStorageClient initialized - bucket={config.bucket_name}, " 132 f"project={config.project_id}" 133 ) 134 135 except Exception as e: 136 logger.error(f"Failed to initialize GCS client: {e}") 137 raise StorageConnectionError( 138 f"Failed to connect to GCS: {e}" 139 ) from e
Initialize Google Cloud Storage Client
Args: config: GCSConfig instance
Raises: StorageConfigError: 設定が不正な場合 StorageConnectionError: 接続に失敗した場合
160 async def ensure_bucket_exists(self) -> bool: 161 """バケットの存在確認・作成""" 162 if not self._enabled: 163 return False 164 165 try: 166 self._bucket = self._gcs_client.bucket(self.bucket_name) 167 168 exists = await asyncio.get_running_loop().run_in_executor( 169 None, self._bucket.exists 170 ) 171 if exists: 172 return True 173 174 if self._gcs_config.auto_create_bucket: 175 self._bucket = await asyncio.get_running_loop().run_in_executor( 176 None, 177 lambda: self._gcs_client.create_bucket(self.bucket_name) 178 ) 179 logger.info(f"Created GCS bucket: {self.bucket_name}") 180 return True 181 else: 182 logger.error(f"Bucket {self.bucket_name} does not exist") 183 return False 184 185 except Exception as e: 186 logger.error(f"Failed to ensure bucket exists: {e}") 187 return False
バケットの存在確認・作成
195 async def upload_file( 196 self, 197 file_path: str, 198 object_name: Optional[str] = None, 199 prefix: Optional[str] = None, 200 metadata: Optional[Dict[str, Any]] = None, 201 ) -> UploadResult: 202 """ 203 ファイルをGCSにアップロード 204 205 Args: 206 file_path: アップロードするファイルのローカルパス 207 object_name: オブジェクト名(省略時はファイル名が使用される) 208 prefix: オブジェクト名プレフィックス(フォルダ構造を作成) 209 metadata: 追加メタデータ(日本語はBase64エンコードされる) 210 211 Returns: 212 UploadResult: アップロード結果 213 """ 214 if not self._enabled: 215 return UploadResult( 216 success=False, 217 error="GCS Storage is not enabled", 218 error_code="STORAGE_NOT_ENABLED", 219 ) 220 221 try: 222 self._validate_file_size(file_path) 223 224 file_path_obj = Path(file_path) 225 226 if not object_name: 227 object_name = file_path_obj.name 228 229 safe_object_name = self._encode_object_name(object_name) 230 231 if prefix: 232 full_object_name = f"{prefix}/{safe_object_name}" 233 else: 234 full_object_name = safe_object_name 235 236 if not await self.ensure_bucket_exists(): 237 return UploadResult( 238 success=False, 239 error="Bucket is not available", 240 error_code="BUCKET_NOT_AVAILABLE", 241 ) 242 243 content_type = self._get_content_type(file_path) 244 safe_metadata = self._sanitize_metadata(metadata) 245 246 bucket = self._get_bucket() 247 blob = bucket.blob(full_object_name) 248 249 if safe_metadata: 250 blob.metadata = safe_metadata 251 252 file_size = file_path_obj.stat().st_size 253 254 await asyncio.get_running_loop().run_in_executor( 255 None, 256 lambda: blob.upload_from_filename( 257 str(file_path_obj), 258 content_type=content_type, 259 ) 260 ) 261 262 object_url = self._generate_public_url(full_object_name) 263 264 logger.info("Successfully uploaded file to GCS") 265 266 return UploadResult( 267 success=True, 268 object_name=full_object_name, 269 object_url=object_url, 270 file_size=file_size, 271 content_type=content_type, 272 metadata={ 273 "uploaded_at": time.time(), 274 "prefix": prefix, 275 **(safe_metadata or {}), 276 }, 277 ) 278 279 except StorageOperationError: 280 raise 281 except Exception as e: 282 logger.error(f"Failed to upload file: {e}") 283 return UploadResult( 284 success=False, 285 error=f"Upload failed: {e}", 286 error_code="UPLOAD_FAILED", 287 )
ファイルをGCSにアップロード
Args: file_path: アップロードするファイルのローカルパス object_name: オブジェクト名(省略時はファイル名が使用される) prefix: オブジェクト名プレフィックス(フォルダ構造を作成) metadata: 追加メタデータ(日本語はBase64エンコードされる)
Returns: UploadResult: アップロード結果
289 async def download_file( 290 self, 291 object_name: str, 292 download_path: str, 293 ) -> DownloadResult: 294 """ 295 ファイルをGCSからダウンロード 296 297 Args: 298 object_name: ダウンロードするオブジェクト名 299 download_path: ダウンロード先のローカルパス 300 301 Returns: 302 DownloadResult: ダウンロード結果 303 """ 304 if not self._enabled: 305 return DownloadResult( 306 success=False, 307 error="GCS Storage is not enabled", 308 error_code="STORAGE_NOT_ENABLED", 309 ) 310 311 try: 312 bucket = self._get_bucket() 313 blob = bucket.blob(object_name) 314 315 exists = await asyncio.get_running_loop().run_in_executor( 316 None, blob.exists 317 ) 318 if not exists: 319 return DownloadResult( 320 success=False, 321 object_name=object_name, 322 error=f"Object not found: {object_name}", 323 error_code="OBJECT_NOT_FOUND", 324 ) 325 326 download_path_obj = Path(download_path) 327 download_path_obj.parent.mkdir(parents=True, exist_ok=True) 328 329 await asyncio.get_running_loop().run_in_executor( 330 None, 331 lambda: blob.download_to_filename(str(download_path_obj)) 332 ) 333 334 file_size = download_path_obj.stat().st_size 335 logger.info("Successfully downloaded object from GCS") 336 337 return DownloadResult( 338 success=True, 339 object_name=object_name, 340 local_path=str(download_path_obj), 341 file_size=file_size, 342 ) 343 344 except Exception as e: 345 logger.error(f"Failed to download object: {e}") 346 return DownloadResult( 347 success=False, 348 object_name=object_name, 349 error=f"Download failed: {e}", 350 error_code="DOWNLOAD_FAILED", 351 )
ファイルをGCSからダウンロード
Args: object_name: ダウンロードするオブジェクト名 download_path: ダウンロード先のローカルパス
Returns: DownloadResult: ダウンロード結果
353 async def list_objects( 354 self, 355 prefix: str = "", 356 max_results: Optional[int] = None, 357 ) -> ListResult: 358 """ 359 オブジェクト一覧を取得 360 361 Args: 362 prefix: プレフィックスフィルタ 363 max_results: 最大取得数 364 365 Returns: 366 ListResult: 一覧結果 367 """ 368 if not self._enabled: 369 return ListResult( 370 success=False, 371 error="GCS Storage is not enabled", 372 error_code="STORAGE_NOT_ENABLED", 373 ) 374 375 try: 376 bucket = self._get_bucket() 377 378 list_kwargs = {} 379 if prefix: 380 list_kwargs['prefix'] = prefix 381 if max_results: 382 list_kwargs['max_results'] = max_results 383 384 def _list(): 385 blobs = list(bucket.list_blobs(**list_kwargs)) 386 objects = [] 387 for blob in blobs: 388 objects.append(ObjectInfo( 389 name=blob.name, 390 size=blob.size or 0, 391 last_modified=blob.updated.isoformat() if blob.updated else None, 392 content_type=blob.content_type, 393 )) 394 return objects 395 396 objects = await asyncio.get_running_loop().run_in_executor(None, _list) 397 398 logger.info(f"Listed {len(objects)} objects from GCS") 399 400 return ListResult( 401 success=True, 402 objects=objects, 403 count=len(objects), 404 prefix=prefix, 405 ) 406 407 except Exception as e: 408 logger.error(f"Failed to list objects: {e}") 409 return ListResult( 410 success=False, 411 error=f"List failed: {e}", 412 error_code="LIST_FAILED", 413 )
オブジェクト一覧を取得
Args: prefix: プレフィックスフィルタ max_results: 最大取得数
Returns: ListResult: 一覧結果
415 async def delete_object(self, object_name: str) -> bool: 416 """オブジェクトを削除""" 417 if not self._enabled: 418 return False 419 420 try: 421 bucket = self._get_bucket() 422 blob = bucket.blob(object_name) 423 424 await asyncio.get_running_loop().run_in_executor( 425 None, blob.delete 426 ) 427 logger.info("Successfully deleted object from GCS") 428 return True 429 430 except Exception as e: 431 logger.error(f"Failed to delete object: {e}") 432 return False
オブジェクトを削除
434 async def object_exists(self, object_name: str) -> bool: 435 """オブジェクトの存在確認""" 436 if not self._enabled: 437 return False 438 439 try: 440 bucket = self._get_bucket() 441 blob = bucket.blob(object_name) 442 return await asyncio.get_running_loop().run_in_executor( 443 None, blob.exists 444 ) 445 except Exception as e: 446 logger.error(f"Failed to check object existence: {e}") 447 return False
オブジェクトの存在確認
449 async def get_object_url( 450 self, 451 object_name: str, 452 expires_in: Optional[int] = None, 453 ) -> str: 454 """ 455 署名付きURLを生成 456 457 Args: 458 object_name: オブジェクト名 459 expires_in: 有効期限(秒)。省略時は公開URLを返す 460 461 Returns: 462 署名付きURL or 公開URL 463 """ 464 if not self._enabled: 465 return "" 466 467 if expires_in: 468 try: 469 import datetime 470 bucket = self._get_bucket() 471 blob = bucket.blob(object_name) 472 url = await asyncio.get_running_loop().run_in_executor( 473 None, 474 lambda: blob.generate_signed_url( 475 expiration=datetime.timedelta(seconds=expires_in), 476 method='GET', 477 ) 478 ) 479 return url 480 except Exception as e: 481 logger.error(f"Failed to generate signed URL: {e}") 482 return "" 483 484 return self._generate_public_url(object_name)
署名付きURLを生成
Args: object_name: オブジェクト名 expires_in: 有効期限(秒)。省略時は公開URLを返す
Returns: 署名付きURL or 公開URL
488 async def close(self) -> None: 489 """クライアントリソースをクローズ""" 490 try: 491 if self._gcs_client: 492 self._gcs_client.close() 493 except Exception as e: 494 logger.warning(f"Error closing GCS client: {e}") 495 finally: 496 self._gcs_client = None 497 self._bucket = None 498 self._enabled = False 499 logger.debug("GCSStorageClient closed")
クライアントリソースをクローズ
52@dataclass 53class UploadResult: 54 """アップロード結果""" 55 success: bool 56 object_name: str = "" 57 object_url: str = "" 58 file_size: int = 0 59 content_type: str = "" 60 error: Optional[str] = None 61 error_code: Optional[str] = None 62 metadata: Dict[str, Any] = field(default_factory=dict)
アップロード結果
65@dataclass 66class DownloadResult: 67 """ダウンロード結果""" 68 success: bool 69 object_name: str = "" 70 local_path: str = "" 71 file_size: int = 0 72 error: Optional[str] = None 73 error_code: Optional[str] = None
ダウンロード結果
76@dataclass 77class ObjectInfo: 78 """オブジェクト情報""" 79 name: str 80 size: int 81 last_modified: Optional[str] = None 82 content_type: Optional[str] = None 83 metadata: Dict[str, Any] = field(default_factory=dict)
オブジェクト情報
86@dataclass 87class ListResult: 88 """一覧取得結果""" 89 success: bool 90 objects: List[ObjectInfo] = field(default_factory=list) 91 count: int = 0 92 prefix: str = "" 93 error: Optional[str] = None 94 error_code: Optional[str] = None
一覧取得結果
21class SecurityProvider(str, Enum): 22 """セキュリティプロバイダー種別""" 23 AZURE = "azure" # Azure Content Safety + Language Service 24 AWS = "aws" # AWS Comprehend 25 GCP = "gcp" # Google Cloud DLP
セキュリティプロバイダー種別
28class ContentCategory(str, Enum): 29 """有害コンテンツカテゴリ(プロバイダー共通)""" 30 HATE = "hate" 31 SEXUAL = "sexual" 32 SELF_HARM = "self_harm" 33 VIOLENCE = "violence" 34 PROFANITY = "profanity" 35 INSULT = "insult" 36 THREAT = "threat"
有害コンテンツカテゴリ(プロバイダー共通)
39class PIICategory(str, Enum): 40 """PIIカテゴリ(プロバイダー共通・主要なもの)""" 41 # 個人識別情報 42 PERSON_NAME = "person_name" 43 EMAIL = "email" 44 PHONE_NUMBER = "phone_number" 45 ADDRESS = "address" 46 AGE = "age" 47 DATE_OF_BIRTH = "date_of_birth" 48 49 # 政府発行ID 50 NATIONAL_ID = "national_id" 51 PASSPORT_NUMBER = "passport_number" 52 DRIVERS_LICENSE = "drivers_license" 53 SOCIAL_SECURITY = "social_security" 54 TAX_ID = "tax_id" 55 56 # 金融情報 57 CREDIT_CARD = "credit_card" 58 BANK_ACCOUNT = "bank_account" 59 SWIFT_CODE = "swift_code" 60 IBAN = "iban" 61 62 # 医療情報 63 MEDICAL_RECORD = "medical_record" 64 HEALTH_INSURANCE = "health_insurance" 65 66 # 技術情報 67 IP_ADDRESS = "ip_address" 68 MAC_ADDRESS = "mac_address" 69 URL = "url" 70 71 # 認証情報 72 PASSWORD = "password" 73 API_KEY = "api_key" 74 AWS_ACCESS_KEY = "aws_access_key" 75 CONNECTION_STRING = "connection_string" 76 77 # その他 78 OTHER = "other"
PIIカテゴリ(プロバイダー共通・主要なもの)
Security操作の基底エラー
Security設定エラー
91class SecurityAPIError(SecurityError): 92 """Security API呼び出しエラー""" 93 94 def __init__(self, message: str, status_code: Optional[int] = None, error_code: str = "API_ERROR"): 95 super().__init__(message) 96 self.status_code = status_code 97 self.error_code = error_code
Security API呼び出しエラー
104@dataclass 105class ContentModerationResult: 106 """コンテンツモデレーション結果""" 107 blocked: bool 108 categories: Dict[ContentCategory, int] = field(default_factory=dict) # category -> severity (0-6) 109 threshold: int = 2 110 raw_response: Dict[str, Any] = field(default_factory=dict) 111 error: Optional[str] = None 112 error_code: Optional[str] = None
コンテンツモデレーション結果
115@dataclass 116class PromptShieldResult: 117 """プロンプトシールド(Jailbreak攻撃検出)結果""" 118 attack_detected: bool 119 attack_type: Optional[str] = None # jailbreak, injection, etc. 120 confidence: float = 0.0 121 raw_response: Dict[str, Any] = field(default_factory=dict) 122 error: Optional[str] = None 123 error_code: Optional[str] = None
プロンプトシールド(Jailbreak攻撃検出)結果
126@dataclass 127class PIIEntity: 128 """検出されたPIIエンティティ""" 129 category: PIICategory 130 text: str 131 offset: int 132 length: int 133 confidence: float 134 provider_category: str = "" # プロバイダー固有のカテゴリ名
検出されたPIIエンティティ
137@dataclass 138class PIIDetectionResult: 139 """PII検出結果""" 140 success: bool 141 masked_text: str = "" 142 entities: List[PIIEntity] = field(default_factory=list) 143 categories_detected: List[PIICategory] = field(default_factory=list) 144 raw_response: Dict[str, Any] = field(default_factory=dict) 145 error: Optional[str] = None 146 error_code: Optional[str] = None
PII検出結果
149@dataclass 150class SecurityCheckResult: 151 """統合セキュリティチェック結果""" 152 allowed: bool 153 violations: List[str] = field(default_factory=list) 154 content_moderation: Optional[ContentModerationResult] = None 155 prompt_shield: Optional[PromptShieldResult] = None 156 pii_detection: Optional[PIIDetectionResult] = None
統合セキュリティチェック結果
163@dataclass 164class AzureSecurityConfig: 165 """Azure Security設定 166 167 Note: 168 api_keyはrepr=Falseでログ出力から除外されます。 169 """ 170 # Content Safety 171 content_safety_endpoint: str 172 content_safety_api_key: str = field(repr=False) 173 174 # Language Service (PII) 175 language_endpoint: Optional[str] = None 176 language_api_key: Optional[str] = field(default=None, repr=False) 177 178 # 共通オプション 179 enabled: bool = True 180 prompt_shield_enabled: bool = True 181 moderation_enabled: bool = True 182 moderation_threshold: int = 2 # 0=disabled, 2=recommended, 4=loose, 6=extreme_only 183 pii_enabled: bool = True 184 pii_confidence_threshold: float = 0.7 185 pii_mask_string: str = "***" 186 timeout: float = 30.0 187 188 @property 189 def provider(self) -> SecurityProvider: 190 return SecurityProvider.AZURE 191 192 def __str__(self) -> str: 193 return ( 194 f"AzureSecurityConfig(content_safety_endpoint={self.content_safety_endpoint!r}, " 195 f"language_endpoint={self.language_endpoint!r}, api_keys=***)" 196 )
Azure Security設定
Note: api_keyはrepr=Falseでログ出力から除外されます。
199@dataclass 200class AWSSecurityConfig: 201 """AWS Comprehend設定 202 203 Warning: 204 AWSSecurityClientは現在スタブ実装です。 205 enabled=Falseがデフォルトです。本番環境ではAzureSecurityClientを使用してください。 206 207 Note: 208 aws_access_key_id, aws_secret_access_keyはrepr=Falseでログ出力から除外されます。 209 """ 210 aws_access_key_id: str = field(repr=False) 211 aws_secret_access_key: str = field(repr=False) 212 region_name: str = "us-east-1" 213 214 # 共通オプション 215 # Warning: スタブ実装のためデフォルトはFalse 216 enabled: bool = False 217 moderation_enabled: bool = True 218 pii_enabled: bool = True 219 pii_confidence_threshold: float = 0.7 220 pii_mask_string: str = "***" 221 timeout: float = 30.0 222 223 @property 224 def provider(self) -> SecurityProvider: 225 return SecurityProvider.AWS 226 227 def __str__(self) -> str: 228 return f"AWSSecurityConfig(region_name={self.region_name!r}, credentials=***)"
AWS Comprehend設定
Warning: AWSSecurityClientは現在スタブ実装です。 enabled=Falseがデフォルトです。本番環境ではAzureSecurityClientを使用してください。
Note: aws_access_key_id, aws_secret_access_keyはrepr=Falseでログ出力から除外されます。
231@dataclass 232class GCPSecurityConfig: 233 """Google Cloud DLP設定 234 235 Warning: 236 GCPSecurityClientは現在スタブ実装です。 237 enabled=Falseがデフォルトです。本番環境ではAzureSecurityClientを使用してください。 238 239 Note: 240 credentials_path, credentials_jsonはrepr=Falseでログ出力から除外されます。 241 """ 242 project_id: str 243 credentials_path: Optional[str] = field(default=None, repr=False) 244 credentials_json: Optional[str] = field(default=None, repr=False) 245 246 # 共通オプション 247 # Warning: スタブ実装のためデフォルトはFalse 248 enabled: bool = False 249 pii_enabled: bool = True 250 pii_confidence_threshold: float = 0.7 251 pii_mask_string: str = "***" 252 timeout: float = 30.0 253 254 @property 255 def provider(self) -> SecurityProvider: 256 return SecurityProvider.GCP 257 258 def __str__(self) -> str: 259 return f"GCPSecurityConfig(project_id={self.project_id!r}, credentials=***)"
Google Cloud DLP設定
Warning: GCPSecurityClientは現在スタブ実装です。 enabled=Falseがデフォルトです。本番環境ではAzureSecurityClientを使用してください。
Note: credentials_path, credentials_jsonはrepr=Falseでログ出力から除外されます。
107class AzureSecurityClient(SecurityClientBase): 108 """ 109 Azure Security クライアント 110 111 Azure Content Safety APIとLanguage Service APIを統合し、 112 コンテンツモデレーション、Jailbreak攻撃検出、PII検出・マスキングを提供。 113 114 Example: 115 config = AzureSecurityConfig( 116 content_safety_endpoint="https://xxx.cognitiveservices.azure.com", 117 content_safety_api_key="xxx", 118 ) 119 client = AzureSecurityClient(config) 120 121 # Content Moderation 122 result = await client.check_content_moderation("some text") 123 if result.blocked: 124 print("Content blocked") 125 126 # Prompt Shield 127 result = await client.check_prompt_shield("user prompt") 128 if result.attack_detected: 129 print("Attack detected") 130 131 # PII Detection 132 result = await client.detect_pii("My email is test@example.com") 133 print(result.masked_text) # "My email is ***" 134 135 await client.close() 136 """ 137 138 def __init__(self, config: AzureSecurityConfig): 139 """ 140 Initialize Azure Security Client 141 142 Content SafetyまたはPII(Language Service)のどちらか一方でも設定があれば初期化可能。 143 両方未設定の場合のみエラー。 144 145 Args: 146 config: AzureSecurityConfig instance 147 148 Raises: 149 SecurityConfigError: Content SafetyとPII両方が未設定の場合 150 """ 151 super().__init__(config) 152 153 # Content SafetyまたはPIIのどちらかが設定されていればOK 154 has_content_safety = bool(config.content_safety_endpoint and config.content_safety_api_key) 155 has_pii = bool(config.language_endpoint and config.language_api_key) 156 157 if not has_content_safety and not has_pii: 158 raise SecurityConfigError( 159 "At least one of Content Safety or PII (Language Service) must be configured. " 160 "Provide content_safety_endpoint/api_key or language_endpoint/api_key." 161 ) 162 163 self._azure_config: AzureSecurityConfig = config 164 self._http_client: Optional[httpx.AsyncClient] = None 165 self._enabled = config.enabled 166 self._has_content_safety = has_content_safety 167 self._has_pii = has_pii 168 169 # ログ出力 170 cs_status = f"{config.content_safety_endpoint[:30]}..." if has_content_safety else "disabled" 171 pii_status = "enabled" if has_pii else "disabled" 172 logger.info( 173 f"AzureSecurityClient initialized - " 174 f"content_safety={cs_status}, " 175 f"pii={pii_status}" 176 ) 177 178 @property 179 def provider(self) -> SecurityProvider: 180 return SecurityProvider.AZURE 181 182 def _get_http_client(self) -> httpx.AsyncClient: 183 """HTTP clientを取得(遅延初期化)""" 184 if self._http_client is None or self._http_client.is_closed: 185 self._http_client = httpx.AsyncClient(timeout=self._azure_config.timeout) 186 return self._http_client 187 188 async def check_content_moderation( 189 self, 190 text: str, 191 threshold: Optional[int] = None, 192 ) -> ContentModerationResult: 193 """ 194 有害コンテンツをチェック(Azure Content Safety Moderation API) 195 196 Args: 197 text: チェック対象テキスト(最大10,000文字) 198 threshold: しきい値(0=無効, 2=推奨, 4=緩い, 6=極度のみ) 199 200 Returns: 201 ContentModerationResult: モデレーション結果 202 """ 203 if not self._enabled: 204 return ContentModerationResult(blocked=False, error="Client disabled") 205 206 if not self._has_content_safety: 207 return ContentModerationResult(blocked=False, error="Content Safety not configured") 208 209 if not self._azure_config.moderation_enabled: 210 return ContentModerationResult(blocked=False, threshold=0) 211 212 if not text: 213 return ContentModerationResult(blocked=False) 214 215 effective_threshold = threshold if threshold is not None else self._azure_config.moderation_threshold 216 if effective_threshold == 0: 217 return ContentModerationResult(blocked=False, threshold=0) 218 219 try: 220 endpoint = ( 221 f"{self._azure_config.content_safety_endpoint.rstrip('/')}" 222 f"/contentsafety/text:analyze?api-version=2024-09-01" 223 ) 224 225 request_body = { 226 "text": text[:10000], 227 "categories": ["Hate", "Sexual", "SelfHarm", "Violence"], 228 "outputType": "FourSeverityLevels", 229 } 230 231 client = self._get_http_client() 232 response = await client.post( 233 endpoint, 234 headers={ 235 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.content_safety_api_key), 236 "Content-Type": "application/json", 237 }, 238 json=request_body, 239 ) 240 241 if response.status_code == 200: 242 result = response.json() 243 categories_analysis = result.get("categoriesAnalysis", []) 244 245 # カテゴリ別severity取得 246 categories: Dict[ContentCategory, int] = {} 247 blocked = False 248 249 for cat in categories_analysis: 250 category_name = cat.get("category", "").lower() 251 severity = cat.get("severity", 0) 252 253 normalized_cat = self._normalize_content_category(category_name) 254 categories[normalized_cat] = severity 255 256 if severity >= effective_threshold: 257 blocked = True 258 logger.warning( 259 f"Content moderation blocked: {category_name} " 260 f"severity={severity} >= threshold={effective_threshold}" 261 ) 262 263 return ContentModerationResult( 264 blocked=blocked, 265 categories=categories, 266 threshold=effective_threshold, 267 raw_response=result, 268 ) 269 else: 270 error_msg = f"Status {response.status_code}: {response.text}" 271 logger.error(f"Content Safety Moderation API error: {error_msg}") 272 raise SecurityAPIError(error_msg, response.status_code) 273 274 except SecurityAPIError: 275 raise 276 except httpx.RequestError as e: 277 logger.error(f"Content Safety Moderation request failed: {e}") 278 raise SecurityAPIError(str(e), error_code="REQUEST_ERROR") 279 except Exception as e: 280 logger.error(f"Error in content moderation: {e}") 281 raise SecurityAPIError(str(e), error_code="UNKNOWN_ERROR") 282 283 async def check_prompt_shield( 284 self, 285 user_prompt: str, 286 documents: Optional[List[str]] = None, 287 ) -> PromptShieldResult: 288 """ 289 Jailbreak攻撃を検出(Azure Content Safety Prompt Shield API) 290 291 Args: 292 user_prompt: ユーザープロンプト(最大10,000文字) 293 documents: ドキュメント内容(オプション、外部コンテンツ検証用) 294 295 Returns: 296 PromptShieldResult: Prompt Shield結果 297 """ 298 if not self._enabled: 299 return PromptShieldResult(attack_detected=False, error="Client disabled") 300 301 if not self._has_content_safety: 302 return PromptShieldResult(attack_detected=False, error="Content Safety not configured") 303 304 if not self._azure_config.prompt_shield_enabled: 305 return PromptShieldResult(attack_detected=False) 306 307 if not user_prompt: 308 return PromptShieldResult(attack_detected=False) 309 310 try: 311 endpoint = ( 312 f"{self._azure_config.content_safety_endpoint.rstrip('/')}" 313 f"/contentsafety/text:shieldPrompt?api-version=2024-09-01" 314 ) 315 316 request_body: Dict[str, Any] = { 317 "userPrompt": user_prompt[:10000], 318 } 319 320 if documents: 321 request_body["documents"] = documents 322 323 client = self._get_http_client() 324 response = await client.post( 325 endpoint, 326 headers={ 327 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.content_safety_api_key), 328 "Content-Type": "application/json", 329 }, 330 json=request_body, 331 ) 332 333 if response.status_code == 200: 334 result = response.json() 335 user_analysis = result.get("userPromptAnalysis", {}) 336 attack_detected = user_analysis.get("attackDetected", False) 337 338 attack_type = None 339 if attack_detected: 340 attack_type = "jailbreak" 341 logger.warning("Prompt Shield detected jailbreak attack") 342 343 return PromptShieldResult( 344 attack_detected=attack_detected, 345 attack_type=attack_type, 346 raw_response=result, 347 ) 348 else: 349 error_msg = f"Status {response.status_code}: {response.text}" 350 logger.error(f"Prompt Shield API error: {error_msg}") 351 raise SecurityAPIError(error_msg, response.status_code) 352 353 except SecurityAPIError: 354 raise 355 except httpx.RequestError as e: 356 logger.error(f"Prompt Shield request failed: {e}") 357 raise SecurityAPIError(str(e), error_code="REQUEST_ERROR") 358 except Exception as e: 359 logger.error(f"Error in prompt shield: {e}") 360 raise SecurityAPIError(str(e), error_code="UNKNOWN_ERROR") 361 362 async def detect_pii( 363 self, 364 text: str, 365 mask: bool = True, 366 ) -> PIIDetectionResult: 367 """ 368 PIIを検出・マスク(Azure Language Service PII Detection API) 369 370 Args: 371 text: チェック対象テキスト(最大125,000文字) 372 mask: マスク処理を行うかどうか 373 374 Returns: 375 PIIDetectionResult: PII検出結果 376 """ 377 if not self._enabled: 378 return PIIDetectionResult(success=False, error="Client disabled") 379 380 if not self._has_pii: 381 return PIIDetectionResult(success=False, error="PII (Language Service) not configured") 382 383 if not self._azure_config.pii_enabled: 384 return PIIDetectionResult(success=True, masked_text=text) 385 386 if not text: 387 return PIIDetectionResult(success=True, masked_text="") 388 389 try: 390 endpoint = ( 391 f"{self._azure_config.language_endpoint.rstrip('/')}" 392 f"/language/:analyze-text?api-version=2024-11-01" 393 ) 394 395 request_body = { 396 "kind": "PiiEntityRecognition", 397 "parameters": { 398 "modelVersion": "latest", 399 "domain": "none", 400 "loggingOptOut": True, 401 "piiCategories": AZURE_PII_CATEGORIES, 402 }, 403 "analysisInput": { 404 "documents": [ 405 { 406 "id": "1", 407 "language": "ja", 408 "text": text[:125000], 409 } 410 ] 411 }, 412 } 413 414 client = self._get_http_client() 415 response = await client.post( 416 endpoint, 417 headers={ 418 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.language_api_key), 419 "Content-Type": "application/json", 420 }, 421 json=request_body, 422 ) 423 424 if response.status_code == 200: 425 result = response.json() 426 documents = result.get("results", {}).get("documents", []) 427 428 if not documents: 429 return PIIDetectionResult( 430 success=True, 431 masked_text="" if mask else text, # データ保護優先 432 ) 433 434 doc_result = documents[0] 435 entities_raw = doc_result.get("entities", []) 436 437 # エンティティをPIIEntityに変換 438 entities: List[PIIEntity] = [] 439 categories_detected: List[PIICategory] = [] 440 441 for entity in entities_raw: 442 confidence = entity.get("confidenceScore", 0) 443 if confidence < self._azure_config.pii_confidence_threshold: 444 continue 445 446 provider_category = entity.get("category", "Unknown") 447 normalized_category = self._normalize_pii_category(provider_category) 448 449 entities.append(PIIEntity( 450 category=normalized_category, 451 text=entity.get("text", ""), 452 offset=entity.get("offset", 0), 453 length=entity.get("length", 0), 454 confidence=confidence, 455 provider_category=provider_category, 456 )) 457 458 if normalized_category not in categories_detected: 459 categories_detected.append(normalized_category) 460 461 # マスク処理 462 masked_text = text 463 if mask and entities: 464 # オフセット順で降順ソート(後ろから処理) 465 sorted_entities = sorted(entities, key=lambda x: x.offset, reverse=True) 466 for entity in sorted_entities: 467 masked_text = ( 468 masked_text[:entity.offset] 469 + self._azure_config.pii_mask_string 470 + masked_text[entity.offset + entity.length:] 471 ) 472 473 if entities: 474 logger.info(f"PII detected and masked: {categories_detected}") 475 476 return PIIDetectionResult( 477 success=True, 478 masked_text=masked_text, 479 entities=entities, 480 categories_detected=categories_detected, 481 raw_response=result, 482 ) 483 484 else: 485 error_msg = f"Status {response.status_code}: {response.text}" 486 logger.error(f"Language Service PII API error: {error_msg}") 487 # PII検出エラー時は空文字を返す(データ保護優先) 488 return PIIDetectionResult( 489 success=False, 490 masked_text="", 491 error=error_msg, 492 error_code="API_ERROR", 493 ) 494 495 except httpx.RequestError as e: 496 logger.error(f"Language Service PII request failed: {e}") 497 return PIIDetectionResult( 498 success=False, 499 masked_text="", 500 error=str(e), 501 error_code="REQUEST_ERROR", 502 ) 503 except Exception as e: 504 logger.error(f"Error in PII detection: {e}") 505 return PIIDetectionResult( 506 success=False, 507 masked_text="", 508 error=str(e), 509 error_code="UNKNOWN_ERROR", 510 ) 511 512 async def close(self) -> None: 513 """クライアントリソースをクローズ 514 515 Note: 516 例外が発生してもリソースを確実にクリーンアップします。 517 """ 518 try: 519 if self._http_client and not self._http_client.is_closed: 520 await self._http_client.aclose() 521 except Exception as e: 522 logger.warning(f"Error closing HTTP client: {e}") 523 finally: 524 self._http_client = None 525 self._enabled = False 526 logger.debug("AzureSecurityClient closed")
Azure Security クライアント
Azure Content Safety APIとLanguage Service APIを統合し、 コンテンツモデレーション、Jailbreak攻撃検出、PII検出・マスキングを提供。
Example: config = AzureSecurityConfig( content_safety_endpoint="https://xxx.cognitiveservices.azure.com", content_safety_api_key="xxx", ) client = AzureSecurityClient(config)
# Content Moderation
result = await client.check_content_moderation("some text")
if result.blocked:
print("Content blocked")
# Prompt Shield
result = await client.check_prompt_shield("user prompt")
if result.attack_detected:
print("Attack detected")
# PII Detection
result = await client.detect_pii("My email is test@example.com")
print(result.masked_text) # "My email is ***"
await client.close()
138 def __init__(self, config: AzureSecurityConfig): 139 """ 140 Initialize Azure Security Client 141 142 Content SafetyまたはPII(Language Service)のどちらか一方でも設定があれば初期化可能。 143 両方未設定の場合のみエラー。 144 145 Args: 146 config: AzureSecurityConfig instance 147 148 Raises: 149 SecurityConfigError: Content SafetyとPII両方が未設定の場合 150 """ 151 super().__init__(config) 152 153 # Content SafetyまたはPIIのどちらかが設定されていればOK 154 has_content_safety = bool(config.content_safety_endpoint and config.content_safety_api_key) 155 has_pii = bool(config.language_endpoint and config.language_api_key) 156 157 if not has_content_safety and not has_pii: 158 raise SecurityConfigError( 159 "At least one of Content Safety or PII (Language Service) must be configured. " 160 "Provide content_safety_endpoint/api_key or language_endpoint/api_key." 161 ) 162 163 self._azure_config: AzureSecurityConfig = config 164 self._http_client: Optional[httpx.AsyncClient] = None 165 self._enabled = config.enabled 166 self._has_content_safety = has_content_safety 167 self._has_pii = has_pii 168 169 # ログ出力 170 cs_status = f"{config.content_safety_endpoint[:30]}..." if has_content_safety else "disabled" 171 pii_status = "enabled" if has_pii else "disabled" 172 logger.info( 173 f"AzureSecurityClient initialized - " 174 f"content_safety={cs_status}, " 175 f"pii={pii_status}" 176 )
Initialize Azure Security Client
Content SafetyまたはPII(Language Service)のどちらか一方でも設定があれば初期化可能。 両方未設定の場合のみエラー。
Args: config: AzureSecurityConfig instance
Raises: SecurityConfigError: Content SafetyとPII両方が未設定の場合
188 async def check_content_moderation( 189 self, 190 text: str, 191 threshold: Optional[int] = None, 192 ) -> ContentModerationResult: 193 """ 194 有害コンテンツをチェック(Azure Content Safety Moderation API) 195 196 Args: 197 text: チェック対象テキスト(最大10,000文字) 198 threshold: しきい値(0=無効, 2=推奨, 4=緩い, 6=極度のみ) 199 200 Returns: 201 ContentModerationResult: モデレーション結果 202 """ 203 if not self._enabled: 204 return ContentModerationResult(blocked=False, error="Client disabled") 205 206 if not self._has_content_safety: 207 return ContentModerationResult(blocked=False, error="Content Safety not configured") 208 209 if not self._azure_config.moderation_enabled: 210 return ContentModerationResult(blocked=False, threshold=0) 211 212 if not text: 213 return ContentModerationResult(blocked=False) 214 215 effective_threshold = threshold if threshold is not None else self._azure_config.moderation_threshold 216 if effective_threshold == 0: 217 return ContentModerationResult(blocked=False, threshold=0) 218 219 try: 220 endpoint = ( 221 f"{self._azure_config.content_safety_endpoint.rstrip('/')}" 222 f"/contentsafety/text:analyze?api-version=2024-09-01" 223 ) 224 225 request_body = { 226 "text": text[:10000], 227 "categories": ["Hate", "Sexual", "SelfHarm", "Violence"], 228 "outputType": "FourSeverityLevels", 229 } 230 231 client = self._get_http_client() 232 response = await client.post( 233 endpoint, 234 headers={ 235 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.content_safety_api_key), 236 "Content-Type": "application/json", 237 }, 238 json=request_body, 239 ) 240 241 if response.status_code == 200: 242 result = response.json() 243 categories_analysis = result.get("categoriesAnalysis", []) 244 245 # カテゴリ別severity取得 246 categories: Dict[ContentCategory, int] = {} 247 blocked = False 248 249 for cat in categories_analysis: 250 category_name = cat.get("category", "").lower() 251 severity = cat.get("severity", 0) 252 253 normalized_cat = self._normalize_content_category(category_name) 254 categories[normalized_cat] = severity 255 256 if severity >= effective_threshold: 257 blocked = True 258 logger.warning( 259 f"Content moderation blocked: {category_name} " 260 f"severity={severity} >= threshold={effective_threshold}" 261 ) 262 263 return ContentModerationResult( 264 blocked=blocked, 265 categories=categories, 266 threshold=effective_threshold, 267 raw_response=result, 268 ) 269 else: 270 error_msg = f"Status {response.status_code}: {response.text}" 271 logger.error(f"Content Safety Moderation API error: {error_msg}") 272 raise SecurityAPIError(error_msg, response.status_code) 273 274 except SecurityAPIError: 275 raise 276 except httpx.RequestError as e: 277 logger.error(f"Content Safety Moderation request failed: {e}") 278 raise SecurityAPIError(str(e), error_code="REQUEST_ERROR") 279 except Exception as e: 280 logger.error(f"Error in content moderation: {e}") 281 raise SecurityAPIError(str(e), error_code="UNKNOWN_ERROR")
有害コンテンツをチェック(Azure Content Safety Moderation API)
Args: text: チェック対象テキスト(最大10,000文字) threshold: しきい値(0=無効, 2=推奨, 4=緩い, 6=極度のみ)
Returns: ContentModerationResult: モデレーション結果
283 async def check_prompt_shield( 284 self, 285 user_prompt: str, 286 documents: Optional[List[str]] = None, 287 ) -> PromptShieldResult: 288 """ 289 Jailbreak攻撃を検出(Azure Content Safety Prompt Shield API) 290 291 Args: 292 user_prompt: ユーザープロンプト(最大10,000文字) 293 documents: ドキュメント内容(オプション、外部コンテンツ検証用) 294 295 Returns: 296 PromptShieldResult: Prompt Shield結果 297 """ 298 if not self._enabled: 299 return PromptShieldResult(attack_detected=False, error="Client disabled") 300 301 if not self._has_content_safety: 302 return PromptShieldResult(attack_detected=False, error="Content Safety not configured") 303 304 if not self._azure_config.prompt_shield_enabled: 305 return PromptShieldResult(attack_detected=False) 306 307 if not user_prompt: 308 return PromptShieldResult(attack_detected=False) 309 310 try: 311 endpoint = ( 312 f"{self._azure_config.content_safety_endpoint.rstrip('/')}" 313 f"/contentsafety/text:shieldPrompt?api-version=2024-09-01" 314 ) 315 316 request_body: Dict[str, Any] = { 317 "userPrompt": user_prompt[:10000], 318 } 319 320 if documents: 321 request_body["documents"] = documents 322 323 client = self._get_http_client() 324 response = await client.post( 325 endpoint, 326 headers={ 327 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.content_safety_api_key), 328 "Content-Type": "application/json", 329 }, 330 json=request_body, 331 ) 332 333 if response.status_code == 200: 334 result = response.json() 335 user_analysis = result.get("userPromptAnalysis", {}) 336 attack_detected = user_analysis.get("attackDetected", False) 337 338 attack_type = None 339 if attack_detected: 340 attack_type = "jailbreak" 341 logger.warning("Prompt Shield detected jailbreak attack") 342 343 return PromptShieldResult( 344 attack_detected=attack_detected, 345 attack_type=attack_type, 346 raw_response=result, 347 ) 348 else: 349 error_msg = f"Status {response.status_code}: {response.text}" 350 logger.error(f"Prompt Shield API error: {error_msg}") 351 raise SecurityAPIError(error_msg, response.status_code) 352 353 except SecurityAPIError: 354 raise 355 except httpx.RequestError as e: 356 logger.error(f"Prompt Shield request failed: {e}") 357 raise SecurityAPIError(str(e), error_code="REQUEST_ERROR") 358 except Exception as e: 359 logger.error(f"Error in prompt shield: {e}") 360 raise SecurityAPIError(str(e), error_code="UNKNOWN_ERROR")
Jailbreak攻撃を検出(Azure Content Safety Prompt Shield API)
Args: user_prompt: ユーザープロンプト(最大10,000文字) documents: ドキュメント内容(オプション、外部コンテンツ検証用)
Returns: PromptShieldResult: Prompt Shield結果
362 async def detect_pii( 363 self, 364 text: str, 365 mask: bool = True, 366 ) -> PIIDetectionResult: 367 """ 368 PIIを検出・マスク(Azure Language Service PII Detection API) 369 370 Args: 371 text: チェック対象テキスト(最大125,000文字) 372 mask: マスク処理を行うかどうか 373 374 Returns: 375 PIIDetectionResult: PII検出結果 376 """ 377 if not self._enabled: 378 return PIIDetectionResult(success=False, error="Client disabled") 379 380 if not self._has_pii: 381 return PIIDetectionResult(success=False, error="PII (Language Service) not configured") 382 383 if not self._azure_config.pii_enabled: 384 return PIIDetectionResult(success=True, masked_text=text) 385 386 if not text: 387 return PIIDetectionResult(success=True, masked_text="") 388 389 try: 390 endpoint = ( 391 f"{self._azure_config.language_endpoint.rstrip('/')}" 392 f"/language/:analyze-text?api-version=2024-11-01" 393 ) 394 395 request_body = { 396 "kind": "PiiEntityRecognition", 397 "parameters": { 398 "modelVersion": "latest", 399 "domain": "none", 400 "loggingOptOut": True, 401 "piiCategories": AZURE_PII_CATEGORIES, 402 }, 403 "analysisInput": { 404 "documents": [ 405 { 406 "id": "1", 407 "language": "ja", 408 "text": text[:125000], 409 } 410 ] 411 }, 412 } 413 414 client = self._get_http_client() 415 response = await client.post( 416 endpoint, 417 headers={ 418 "Ocp-Apim-Subscription-Key": self._get_api_key(self._azure_config.language_api_key), 419 "Content-Type": "application/json", 420 }, 421 json=request_body, 422 ) 423 424 if response.status_code == 200: 425 result = response.json() 426 documents = result.get("results", {}).get("documents", []) 427 428 if not documents: 429 return PIIDetectionResult( 430 success=True, 431 masked_text="" if mask else text, # データ保護優先 432 ) 433 434 doc_result = documents[0] 435 entities_raw = doc_result.get("entities", []) 436 437 # エンティティをPIIEntityに変換 438 entities: List[PIIEntity] = [] 439 categories_detected: List[PIICategory] = [] 440 441 for entity in entities_raw: 442 confidence = entity.get("confidenceScore", 0) 443 if confidence < self._azure_config.pii_confidence_threshold: 444 continue 445 446 provider_category = entity.get("category", "Unknown") 447 normalized_category = self._normalize_pii_category(provider_category) 448 449 entities.append(PIIEntity( 450 category=normalized_category, 451 text=entity.get("text", ""), 452 offset=entity.get("offset", 0), 453 length=entity.get("length", 0), 454 confidence=confidence, 455 provider_category=provider_category, 456 )) 457 458 if normalized_category not in categories_detected: 459 categories_detected.append(normalized_category) 460 461 # マスク処理 462 masked_text = text 463 if mask and entities: 464 # オフセット順で降順ソート(後ろから処理) 465 sorted_entities = sorted(entities, key=lambda x: x.offset, reverse=True) 466 for entity in sorted_entities: 467 masked_text = ( 468 masked_text[:entity.offset] 469 + self._azure_config.pii_mask_string 470 + masked_text[entity.offset + entity.length:] 471 ) 472 473 if entities: 474 logger.info(f"PII detected and masked: {categories_detected}") 475 476 return PIIDetectionResult( 477 success=True, 478 masked_text=masked_text, 479 entities=entities, 480 categories_detected=categories_detected, 481 raw_response=result, 482 ) 483 484 else: 485 error_msg = f"Status {response.status_code}: {response.text}" 486 logger.error(f"Language Service PII API error: {error_msg}") 487 # PII検出エラー時は空文字を返す(データ保護優先) 488 return PIIDetectionResult( 489 success=False, 490 masked_text="", 491 error=error_msg, 492 error_code="API_ERROR", 493 ) 494 495 except httpx.RequestError as e: 496 logger.error(f"Language Service PII request failed: {e}") 497 return PIIDetectionResult( 498 success=False, 499 masked_text="", 500 error=str(e), 501 error_code="REQUEST_ERROR", 502 ) 503 except Exception as e: 504 logger.error(f"Error in PII detection: {e}") 505 return PIIDetectionResult( 506 success=False, 507 masked_text="", 508 error=str(e), 509 error_code="UNKNOWN_ERROR", 510 )
PIIを検出・マスク(Azure Language Service PII Detection API)
Args: text: チェック対象テキスト(最大125,000文字) mask: マスク処理を行うかどうか
Returns: PIIDetectionResult: PII検出結果
512 async def close(self) -> None: 513 """クライアントリソースをクローズ 514 515 Note: 516 例外が発生してもリソースを確実にクリーンアップします。 517 """ 518 try: 519 if self._http_client and not self._http_client.is_closed: 520 await self._http_client.aclose() 521 except Exception as e: 522 logger.warning(f"Error closing HTTP client: {e}") 523 finally: 524 self._http_client = None 525 self._enabled = False 526 logger.debug("AzureSecurityClient closed")
クライアントリソースをクローズ
Note: 例外が発生してもリソースを確実にクリーンアップします。
58class AWSSecurityClient(SecurityClientBase): 59 """ 60 AWS Comprehend セキュリティクライアント(スタブ実装) 61 62 Note: 63 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 64 detect_pii, check_content_moderation等のメソッドはNOT_IMPLEMENTEDを返します。 65 66 Example: 67 config = AWSSecurityConfig( 68 aws_access_key_id="AKIA...", 69 aws_secret_access_key="...", 70 region_name="us-east-1", 71 ) 72 client = AWSSecurityClient(config) 73 result = await client.detect_pii("My email is test@example.com") 74 """ 75 76 def __init__(self, config: AWSSecurityConfig): 77 """ 78 Initialize AWS Comprehend Security Client 79 80 Args: 81 config: AWSSecurityConfig instance 82 83 Raises: 84 SecurityConfigError: 設定が不正な場合 85 """ 86 super().__init__(config) 87 88 if not BOTO3_AVAILABLE: 89 raise SecurityConfigError( 90 "boto3 package is not installed. " 91 "Install with: pip install boto3" 92 ) 93 94 if not config.aws_access_key_id or not config.aws_secret_access_key: 95 raise SecurityConfigError( 96 "AWS requires aws_access_key_id and aws_secret_access_key" 97 ) 98 99 self._aws_config: AWSSecurityConfig = config 100 self._comprehend_client = None 101 102 try: 103 session = boto3.Session( 104 aws_access_key_id=self._get_api_key(config.aws_access_key_id), 105 aws_secret_access_key=self._get_api_key(config.aws_secret_access_key), 106 region_name=config.region_name, 107 ) 108 109 self._comprehend_client = session.client('comprehend') 110 self._enabled = config.enabled 111 112 logger.warning( 113 "AWSSecurityClient is a STUB implementation. " 114 "detect_pii, check_content_moderation等のメソッドはNOT_IMPLEMENTEDを返します。" 115 ) 116 logger.info( 117 f"AWSSecurityClient initialized - region={config.region_name}" 118 ) 119 120 except Exception as e: 121 logger.error(f"Failed to initialize AWS Comprehend client: {e}") 122 raise SecurityConfigError(f"Failed to connect to AWS Comprehend: {e}") from e 123 124 @property 125 def provider(self) -> SecurityProvider: 126 return SecurityProvider.AWS 127 128 async def check_content_moderation( 129 self, 130 text: str, 131 threshold: Optional[int] = None, 132 ) -> ContentModerationResult: 133 """ 134 有害コンテンツをチェック(AWS Comprehend DetectToxicContent) 135 136 Note: 137 スタブ実装 - NOT_IMPLEMENTEDを返します 138 139 TODO: 140 - DetectToxicContent APIの実装 141 - カテゴリマッピング: HATE_SPEECH, SEXUAL, VIOLENCE_OR_THREAT等 142 """ 143 # TODO: 本格実装 144 # aws comprehend detect-toxic-content --text-segments "Text=<text>" --language-code ja 145 return ContentModerationResult( 146 blocked=False, 147 error="AWS Comprehend content moderation not yet implemented", 148 error_code="NOT_IMPLEMENTED", 149 ) 150 151 async def check_prompt_shield( 152 self, 153 user_prompt: str, 154 documents: Optional[List[str]] = None, 155 ) -> PromptShieldResult: 156 """ 157 Jailbreak攻撃を検出 158 159 Note: 160 AWS ComprehendにはPrompt Shield相当の機能がないため、 161 常にattack_detected=Falseを返します。 162 """ 163 # AWS Comprehendには直接的なPrompt Shield機能がない 164 return PromptShieldResult( 165 attack_detected=False, 166 error="AWS Comprehend does not support Prompt Shield functionality", 167 error_code="NOT_SUPPORTED", 168 ) 169 170 async def detect_pii( 171 self, 172 text: str, 173 mask: bool = True, 174 ) -> PIIDetectionResult: 175 """ 176 PIIを検出・マスク(AWS Comprehend DetectPiiEntities) 177 178 Note: 179 スタブ実装 - NOT_IMPLEMENTEDを返します 180 181 TODO: 182 - DetectPiiEntities APIの実装 183 - サポートカテゴリ: BANK_ACCOUNT_NUMBER, BANK_ROUTING, 184 CREDIT_DEBIT_CVV, CREDIT_DEBIT_EXPIRY, CREDIT_DEBIT_NUMBER, 185 EMAIL, IP_ADDRESS, MAC_ADDRESS, NAME, PASSWORD, PHONE, 186 PIN, SSN, URL, USERNAME等 187 """ 188 # TODO: 本格実装 189 # import asyncio 190 # response = await asyncio.get_event_loop().run_in_executor( 191 # None, 192 # lambda: self._comprehend_client.detect_pii_entities( 193 # Text=text[:5000], # API制限 194 # LanguageCode='ja' 195 # ) 196 # ) 197 return PIIDetectionResult( 198 success=False, 199 masked_text="", 200 error="AWS Comprehend PII detection not yet implemented", 201 error_code="NOT_IMPLEMENTED", 202 ) 203 204 async def close(self) -> None: 205 """クライアントリソースをクローズ 206 207 Note: 208 boto3クライアントは明示的なclose不要ですが、 209 参照をクリアしてリソースを確実に解放します。 210 """ 211 try: 212 # boto3クライアントは明示的なclose不要だが、参照をクリア 213 pass 214 except Exception as e: 215 logger.warning(f"Error during cleanup: {e}") 216 finally: 217 self._comprehend_client = None 218 self._enabled = False 219 logger.debug("AWSSecurityClient closed")
AWS Comprehend セキュリティクライアント(スタブ実装)
Note: 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 detect_pii, check_content_moderation等のメソッドはNOT_IMPLEMENTEDを返します。
Example: config = AWSSecurityConfig( aws_access_key_id="AKIA...", aws_secret_access_key="...", region_name="us-east-1", ) client = AWSSecurityClient(config) result = await client.detect_pii("My email is test@example.com")
76 def __init__(self, config: AWSSecurityConfig): 77 """ 78 Initialize AWS Comprehend Security Client 79 80 Args: 81 config: AWSSecurityConfig instance 82 83 Raises: 84 SecurityConfigError: 設定が不正な場合 85 """ 86 super().__init__(config) 87 88 if not BOTO3_AVAILABLE: 89 raise SecurityConfigError( 90 "boto3 package is not installed. " 91 "Install with: pip install boto3" 92 ) 93 94 if not config.aws_access_key_id or not config.aws_secret_access_key: 95 raise SecurityConfigError( 96 "AWS requires aws_access_key_id and aws_secret_access_key" 97 ) 98 99 self._aws_config: AWSSecurityConfig = config 100 self._comprehend_client = None 101 102 try: 103 session = boto3.Session( 104 aws_access_key_id=self._get_api_key(config.aws_access_key_id), 105 aws_secret_access_key=self._get_api_key(config.aws_secret_access_key), 106 region_name=config.region_name, 107 ) 108 109 self._comprehend_client = session.client('comprehend') 110 self._enabled = config.enabled 111 112 logger.warning( 113 "AWSSecurityClient is a STUB implementation. " 114 "detect_pii, check_content_moderation等のメソッドはNOT_IMPLEMENTEDを返します。" 115 ) 116 logger.info( 117 f"AWSSecurityClient initialized - region={config.region_name}" 118 ) 119 120 except Exception as e: 121 logger.error(f"Failed to initialize AWS Comprehend client: {e}") 122 raise SecurityConfigError(f"Failed to connect to AWS Comprehend: {e}") from e
Initialize AWS Comprehend Security Client
Args: config: AWSSecurityConfig instance
Raises: SecurityConfigError: 設定が不正な場合
128 async def check_content_moderation( 129 self, 130 text: str, 131 threshold: Optional[int] = None, 132 ) -> ContentModerationResult: 133 """ 134 有害コンテンツをチェック(AWS Comprehend DetectToxicContent) 135 136 Note: 137 スタブ実装 - NOT_IMPLEMENTEDを返します 138 139 TODO: 140 - DetectToxicContent APIの実装 141 - カテゴリマッピング: HATE_SPEECH, SEXUAL, VIOLENCE_OR_THREAT等 142 """ 143 # TODO: 本格実装 144 # aws comprehend detect-toxic-content --text-segments "Text=<text>" --language-code ja 145 return ContentModerationResult( 146 blocked=False, 147 error="AWS Comprehend content moderation not yet implemented", 148 error_code="NOT_IMPLEMENTED", 149 )
有害コンテンツをチェック(AWS Comprehend DetectToxicContent)
Note: スタブ実装 - NOT_IMPLEMENTEDを返します
TODO: - DetectToxicContent APIの実装 - カテゴリマッピング: HATE_SPEECH, SEXUAL, VIOLENCE_OR_THREAT等
151 async def check_prompt_shield( 152 self, 153 user_prompt: str, 154 documents: Optional[List[str]] = None, 155 ) -> PromptShieldResult: 156 """ 157 Jailbreak攻撃を検出 158 159 Note: 160 AWS ComprehendにはPrompt Shield相当の機能がないため、 161 常にattack_detected=Falseを返します。 162 """ 163 # AWS Comprehendには直接的なPrompt Shield機能がない 164 return PromptShieldResult( 165 attack_detected=False, 166 error="AWS Comprehend does not support Prompt Shield functionality", 167 error_code="NOT_SUPPORTED", 168 )
Jailbreak攻撃を検出
Note: AWS ComprehendにはPrompt Shield相当の機能がないため、 常にattack_detected=Falseを返します。
170 async def detect_pii( 171 self, 172 text: str, 173 mask: bool = True, 174 ) -> PIIDetectionResult: 175 """ 176 PIIを検出・マスク(AWS Comprehend DetectPiiEntities) 177 178 Note: 179 スタブ実装 - NOT_IMPLEMENTEDを返します 180 181 TODO: 182 - DetectPiiEntities APIの実装 183 - サポートカテゴリ: BANK_ACCOUNT_NUMBER, BANK_ROUTING, 184 CREDIT_DEBIT_CVV, CREDIT_DEBIT_EXPIRY, CREDIT_DEBIT_NUMBER, 185 EMAIL, IP_ADDRESS, MAC_ADDRESS, NAME, PASSWORD, PHONE, 186 PIN, SSN, URL, USERNAME等 187 """ 188 # TODO: 本格実装 189 # import asyncio 190 # response = await asyncio.get_event_loop().run_in_executor( 191 # None, 192 # lambda: self._comprehend_client.detect_pii_entities( 193 # Text=text[:5000], # API制限 194 # LanguageCode='ja' 195 # ) 196 # ) 197 return PIIDetectionResult( 198 success=False, 199 masked_text="", 200 error="AWS Comprehend PII detection not yet implemented", 201 error_code="NOT_IMPLEMENTED", 202 )
PIIを検出・マスク(AWS Comprehend DetectPiiEntities)
Note: スタブ実装 - NOT_IMPLEMENTEDを返します
TODO: - DetectPiiEntities APIの実装 - サポートカテゴリ: BANK_ACCOUNT_NUMBER, BANK_ROUTING, CREDIT_DEBIT_CVV, CREDIT_DEBIT_EXPIRY, CREDIT_DEBIT_NUMBER, EMAIL, IP_ADDRESS, MAC_ADDRESS, NAME, PASSWORD, PHONE, PIN, SSN, URL, USERNAME等
204 async def close(self) -> None: 205 """クライアントリソースをクローズ 206 207 Note: 208 boto3クライアントは明示的なclose不要ですが、 209 参照をクリアしてリソースを確実に解放します。 210 """ 211 try: 212 # boto3クライアントは明示的なclose不要だが、参照をクリア 213 pass 214 except Exception as e: 215 logger.warning(f"Error during cleanup: {e}") 216 finally: 217 self._comprehend_client = None 218 self._enabled = False 219 logger.debug("AWSSecurityClient closed")
クライアントリソースをクローズ
Note: boto3クライアントは明示的なclose不要ですが、 参照をクリアしてリソースを確実に解放します。
58class GCPSecurityClient(SecurityClientBase): 59 """ 60 Google Cloud DLP セキュリティクライアント(スタブ実装) 61 62 Note: 63 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 64 detect_pii等のメソッドはNOT_IMPLEMENTEDを返します。 65 66 Example: 67 config = GCPSecurityConfig( 68 project_id="my-project", 69 credentials_path="/path/to/service-account.json", 70 ) 71 client = GCPSecurityClient(config) 72 result = await client.detect_pii("My email is test@example.com") 73 """ 74 75 def __init__(self, config: GCPSecurityConfig): 76 """ 77 Initialize Google Cloud DLP Security Client 78 79 Args: 80 config: GCPSecurityConfig instance 81 82 Raises: 83 SecurityConfigError: 設定が不正な場合 84 """ 85 super().__init__(config) 86 87 if not GCP_DLP_AVAILABLE: 88 raise SecurityConfigError( 89 "google-cloud-dlp package is not installed. " 90 "Install with: pip install google-cloud-dlp" 91 ) 92 93 if not config.project_id: 94 raise SecurityConfigError("GCP project_id is required") 95 96 self._gcp_config: GCPSecurityConfig = config 97 self._dlp_client = None 98 99 try: 100 # 認証情報の設定 101 if config.credentials_path: 102 credentials = service_account.Credentials.from_service_account_file( 103 config.credentials_path 104 ) 105 self._dlp_client = dlp_v2.DlpServiceClient(credentials=credentials) 106 elif config.credentials_json: 107 import json 108 credentials_info = json.loads(config.credentials_json) 109 credentials = service_account.Credentials.from_service_account_info( 110 credentials_info 111 ) 112 self._dlp_client = dlp_v2.DlpServiceClient(credentials=credentials) 113 else: 114 # デフォルト認証(環境変数 GOOGLE_APPLICATION_CREDENTIALS) 115 self._dlp_client = dlp_v2.DlpServiceClient() 116 117 self._enabled = config.enabled 118 119 logger.warning( 120 "GCPSecurityClient is a STUB implementation. " 121 "detect_pii等のメソッドはNOT_IMPLEMENTEDを返します。" 122 ) 123 logger.info( 124 f"GCPSecurityClient initialized - project={config.project_id}" 125 ) 126 127 except Exception as e: 128 logger.error(f"Failed to initialize GCP DLP client: {e}") 129 raise SecurityConfigError(f"Failed to connect to GCP DLP: {e}") from e 130 131 @property 132 def provider(self) -> SecurityProvider: 133 return SecurityProvider.GCP 134 135 async def check_content_moderation( 136 self, 137 text: str, 138 threshold: Optional[int] = None, 139 ) -> ContentModerationResult: 140 """ 141 有害コンテンツをチェック 142 143 Note: 144 GCP DLPにはContent Moderation機能がないため、 145 常にblocked=Falseを返します。 146 ※ コンテンツモデレーションにはPerspective APIを検討 147 """ 148 # GCP DLPには直接的なContent Moderation機能がない 149 # Perspective API (Google Jigsaw) または Vertex AI Safety Filtersが必要 150 return ContentModerationResult( 151 blocked=False, 152 error="GCP DLP does not support content moderation. Consider Perspective API.", 153 error_code="NOT_SUPPORTED", 154 ) 155 156 async def check_prompt_shield( 157 self, 158 user_prompt: str, 159 documents: Optional[List[str]] = None, 160 ) -> PromptShieldResult: 161 """ 162 Jailbreak攻撃を検出 163 164 Note: 165 GCP DLPにはPrompt Shield相当の機能がないため、 166 常にattack_detected=Falseを返します。 167 """ 168 # GCP DLPには直接的なPrompt Shield機能がない 169 return PromptShieldResult( 170 attack_detected=False, 171 error="GCP DLP does not support Prompt Shield functionality", 172 error_code="NOT_SUPPORTED", 173 ) 174 175 async def detect_pii( 176 self, 177 text: str, 178 mask: bool = True, 179 ) -> PIIDetectionResult: 180 """ 181 PIIを検出・マスク(Google Cloud DLP inspect_content / deidentify_content) 182 183 Note: 184 スタブ実装 - NOT_IMPLEMENTEDを返します 185 186 TODO: 187 - inspect_content APIの実装 188 - deidentify_content APIの実装 189 - サポートInfoType: CREDIT_CARD_NUMBER, EMAIL_ADDRESS, PHONE_NUMBER, 190 STREET_ADDRESS, PERSON_NAME, DATE_OF_BIRTH, PASSPORT, 191 JAPAN_BANK_ACCOUNT, JAPAN_DRIVERS_LICENSE_NUMBER, 192 JAPAN_INDIVIDUAL_NUMBER, JAPAN_PASSPORT等 193 """ 194 # TODO: 本格実装 195 # parent = f"projects/{self._gcp_config.project_id}" 196 # 197 # inspect_config = { 198 # "info_types": [ 199 # {"name": "CREDIT_CARD_NUMBER"}, 200 # {"name": "EMAIL_ADDRESS"}, 201 # {"name": "PHONE_NUMBER"}, 202 # {"name": "PERSON_NAME"}, 203 # ... 204 # ], 205 # "min_likelihood": "POSSIBLE", # VERY_UNLIKELY, UNLIKELY, POSSIBLE, LIKELY, VERY_LIKELY 206 # } 207 # 208 # item = {"value": text} 209 # 210 # # asyncioブロック防止 211 # response = await asyncio.get_event_loop().run_in_executor( 212 # None, 213 # lambda: self._dlp_client.inspect_content( 214 # request={"parent": parent, "inspect_config": inspect_config, "item": item} 215 # ) 216 # ) 217 # 218 # # deidentify_content for masking 219 # deidentify_config = { 220 # "info_type_transformations": { 221 # "transformations": [ 222 # { 223 # "primitive_transformation": { 224 # "replace_config": {"new_value": {"string_value": self._gcp_config.pii_mask_string}} 225 # } 226 # } 227 # ] 228 # } 229 # } 230 return PIIDetectionResult( 231 success=False, 232 masked_text="", 233 error="GCP DLP PII detection not yet implemented", 234 error_code="NOT_IMPLEMENTED", 235 ) 236 237 async def close(self) -> None: 238 """クライアントリソースをクローズ 239 240 Note: 241 gRPCクライアントのクローズ時に例外が発生しても 242 リソースを確実にクリーンアップします。 243 """ 244 try: 245 # gRPCクライアントのクローズ 246 if self._dlp_client: 247 # DlpServiceClientにはclose()メソッドがある場合がある 248 if hasattr(self._dlp_client, 'transport') and hasattr(self._dlp_client.transport, 'close'): 249 self._dlp_client.transport.close() 250 except Exception as e: 251 logger.warning(f"Error closing DLP client: {e}") 252 finally: 253 self._dlp_client = None 254 self._enabled = False 255 logger.debug("GCPSecurityClient closed")
Google Cloud DLP セキュリティクライアント(スタブ実装)
Note: 現在はスタブ実装です。マーケットプレイス向けに本格実装予定。 detect_pii等のメソッドはNOT_IMPLEMENTEDを返します。
Example: config = GCPSecurityConfig( project_id="my-project", credentials_path="/path/to/service-account.json", ) client = GCPSecurityClient(config) result = await client.detect_pii("My email is test@example.com")
75 def __init__(self, config: GCPSecurityConfig): 76 """ 77 Initialize Google Cloud DLP Security Client 78 79 Args: 80 config: GCPSecurityConfig instance 81 82 Raises: 83 SecurityConfigError: 設定が不正な場合 84 """ 85 super().__init__(config) 86 87 if not GCP_DLP_AVAILABLE: 88 raise SecurityConfigError( 89 "google-cloud-dlp package is not installed. " 90 "Install with: pip install google-cloud-dlp" 91 ) 92 93 if not config.project_id: 94 raise SecurityConfigError("GCP project_id is required") 95 96 self._gcp_config: GCPSecurityConfig = config 97 self._dlp_client = None 98 99 try: 100 # 認証情報の設定 101 if config.credentials_path: 102 credentials = service_account.Credentials.from_service_account_file( 103 config.credentials_path 104 ) 105 self._dlp_client = dlp_v2.DlpServiceClient(credentials=credentials) 106 elif config.credentials_json: 107 import json 108 credentials_info = json.loads(config.credentials_json) 109 credentials = service_account.Credentials.from_service_account_info( 110 credentials_info 111 ) 112 self._dlp_client = dlp_v2.DlpServiceClient(credentials=credentials) 113 else: 114 # デフォルト認証(環境変数 GOOGLE_APPLICATION_CREDENTIALS) 115 self._dlp_client = dlp_v2.DlpServiceClient() 116 117 self._enabled = config.enabled 118 119 logger.warning( 120 "GCPSecurityClient is a STUB implementation. " 121 "detect_pii等のメソッドはNOT_IMPLEMENTEDを返します。" 122 ) 123 logger.info( 124 f"GCPSecurityClient initialized - project={config.project_id}" 125 ) 126 127 except Exception as e: 128 logger.error(f"Failed to initialize GCP DLP client: {e}") 129 raise SecurityConfigError(f"Failed to connect to GCP DLP: {e}") from e
Initialize Google Cloud DLP Security Client
Args: config: GCPSecurityConfig instance
Raises: SecurityConfigError: 設定が不正な場合
135 async def check_content_moderation( 136 self, 137 text: str, 138 threshold: Optional[int] = None, 139 ) -> ContentModerationResult: 140 """ 141 有害コンテンツをチェック 142 143 Note: 144 GCP DLPにはContent Moderation機能がないため、 145 常にblocked=Falseを返します。 146 ※ コンテンツモデレーションにはPerspective APIを検討 147 """ 148 # GCP DLPには直接的なContent Moderation機能がない 149 # Perspective API (Google Jigsaw) または Vertex AI Safety Filtersが必要 150 return ContentModerationResult( 151 blocked=False, 152 error="GCP DLP does not support content moderation. Consider Perspective API.", 153 error_code="NOT_SUPPORTED", 154 )
有害コンテンツをチェック
Note: GCP DLPにはContent Moderation機能がないため、 常にblocked=Falseを返します。 ※ コンテンツモデレーションにはPerspective APIを検討
156 async def check_prompt_shield( 157 self, 158 user_prompt: str, 159 documents: Optional[List[str]] = None, 160 ) -> PromptShieldResult: 161 """ 162 Jailbreak攻撃を検出 163 164 Note: 165 GCP DLPにはPrompt Shield相当の機能がないため、 166 常にattack_detected=Falseを返します。 167 """ 168 # GCP DLPには直接的なPrompt Shield機能がない 169 return PromptShieldResult( 170 attack_detected=False, 171 error="GCP DLP does not support Prompt Shield functionality", 172 error_code="NOT_SUPPORTED", 173 )
Jailbreak攻撃を検出
Note: GCP DLPにはPrompt Shield相当の機能がないため、 常にattack_detected=Falseを返します。
175 async def detect_pii( 176 self, 177 text: str, 178 mask: bool = True, 179 ) -> PIIDetectionResult: 180 """ 181 PIIを検出・マスク(Google Cloud DLP inspect_content / deidentify_content) 182 183 Note: 184 スタブ実装 - NOT_IMPLEMENTEDを返します 185 186 TODO: 187 - inspect_content APIの実装 188 - deidentify_content APIの実装 189 - サポートInfoType: CREDIT_CARD_NUMBER, EMAIL_ADDRESS, PHONE_NUMBER, 190 STREET_ADDRESS, PERSON_NAME, DATE_OF_BIRTH, PASSPORT, 191 JAPAN_BANK_ACCOUNT, JAPAN_DRIVERS_LICENSE_NUMBER, 192 JAPAN_INDIVIDUAL_NUMBER, JAPAN_PASSPORT等 193 """ 194 # TODO: 本格実装 195 # parent = f"projects/{self._gcp_config.project_id}" 196 # 197 # inspect_config = { 198 # "info_types": [ 199 # {"name": "CREDIT_CARD_NUMBER"}, 200 # {"name": "EMAIL_ADDRESS"}, 201 # {"name": "PHONE_NUMBER"}, 202 # {"name": "PERSON_NAME"}, 203 # ... 204 # ], 205 # "min_likelihood": "POSSIBLE", # VERY_UNLIKELY, UNLIKELY, POSSIBLE, LIKELY, VERY_LIKELY 206 # } 207 # 208 # item = {"value": text} 209 # 210 # # asyncioブロック防止 211 # response = await asyncio.get_event_loop().run_in_executor( 212 # None, 213 # lambda: self._dlp_client.inspect_content( 214 # request={"parent": parent, "inspect_config": inspect_config, "item": item} 215 # ) 216 # ) 217 # 218 # # deidentify_content for masking 219 # deidentify_config = { 220 # "info_type_transformations": { 221 # "transformations": [ 222 # { 223 # "primitive_transformation": { 224 # "replace_config": {"new_value": {"string_value": self._gcp_config.pii_mask_string}} 225 # } 226 # } 227 # ] 228 # } 229 # } 230 return PIIDetectionResult( 231 success=False, 232 masked_text="", 233 error="GCP DLP PII detection not yet implemented", 234 error_code="NOT_IMPLEMENTED", 235 )
PIIを検出・マスク(Google Cloud DLP inspect_content / deidentify_content)
Note: スタブ実装 - NOT_IMPLEMENTEDを返します
TODO: - inspect_content APIの実装 - deidentify_content APIの実装 - サポートInfoType: CREDIT_CARD_NUMBER, EMAIL_ADDRESS, PHONE_NUMBER, STREET_ADDRESS, PERSON_NAME, DATE_OF_BIRTH, PASSPORT, JAPAN_BANK_ACCOUNT, JAPAN_DRIVERS_LICENSE_NUMBER, JAPAN_INDIVIDUAL_NUMBER, JAPAN_PASSPORT等
237 async def close(self) -> None: 238 """クライアントリソースをクローズ 239 240 Note: 241 gRPCクライアントのクローズ時に例外が発生しても 242 リソースを確実にクリーンアップします。 243 """ 244 try: 245 # gRPCクライアントのクローズ 246 if self._dlp_client: 247 # DlpServiceClientにはclose()メソッドがある場合がある 248 if hasattr(self._dlp_client, 'transport') and hasattr(self._dlp_client.transport, 'close'): 249 self._dlp_client.transport.close() 250 except Exception as e: 251 logger.warning(f"Error closing DLP client: {e}") 252 finally: 253 self._dlp_client = None 254 self._enabled = False 255 logger.debug("GCPSecurityClient closed")
クライアントリソースをクローズ
Note: gRPCクライアントのクローズ時に例外が発生しても リソースを確実にクリーンアップします。
62class AgenticStarAuthClient: 63 """ 64 AgenticStar Auth API クライアント 65 66 認証API(ユーザー管理、MCP OAuthトークン)へのアクセスを提供。 67 X-Internal-API-Key認証を使用します。 68 69 Attributes: 70 config: AgenticStarAuthConfig設定 71 72 Example: 73 ```python 74 config = AgenticStarAuthConfig.create( 75 base_url="https://auth.example.com", 76 api_key="xxx", 77 ) 78 async with AgenticStarAuthClient(config) as client: 79 users = await client.get_users() 80 ``` 81 """ 82 83 def __init__(self, config: AgenticStarAuthConfig): 84 """ 85 AgenticStar Auth Clientを初期化 86 87 Args: 88 config: AgenticStarAuthConfig インスタンス 89 90 Raises: 91 AuthConfigError: 設定が不正な場合 92 """ 93 if not config.base_url: 94 raise AuthConfigError("base_url is required") 95 96 if not config.api_key: 97 raise AuthConfigError("api_key is required") 98 99 self._config = config 100 self._http_client: Optional[httpx.AsyncClient] = None 101 102 auth_type = getattr(config, 'auth_type', 'api_key') 103 logger.info( 104 f"AgenticStarAuthClient initialized - base_url={config.base_url[:30]}..., auth_type={auth_type}" 105 ) 106 107 async def __aenter__(self) -> "AgenticStarAuthClient": 108 """Context Manager: 開始時の処理""" 109 return self 110 111 async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: 112 """Context Manager: 終了時にリソースをクローズ""" 113 await self.close() 114 115 def _get_http_client(self) -> httpx.AsyncClient: 116 """HTTP clientを取得(遅延初期化)""" 117 if self._http_client is None or self._http_client.is_closed: 118 self._http_client = httpx.AsyncClient(timeout=self._config.timeout) 119 return self._http_client 120 121 def _get_headers(self) -> Dict[str, str]: 122 """共通ヘッダーを取得 123 124 auth_typeに応じてヘッダーを切り替え: 125 - "api_key" (Pod mode): X-Internal-API-Key ヘッダー 126 - "bearer" (CLI mode): Authorization: Bearer ヘッダー 127 """ 128 headers = {"Content-Type": "application/json"} 129 130 auth_type = getattr(self._config, 'auth_type', 'api_key') 131 if auth_type == "bearer": 132 # CLIモード: OAuth Bearer token 133 headers["Authorization"] = f"Bearer {self._config.api_key}" 134 else: 135 # Podモード: Internal API Key 136 headers["X-Internal-API-Key"] = self._config.api_key 137 138 return headers 139 140 async def _request( 141 self, 142 method: str, 143 endpoint: str, 144 params: Optional[Dict[str, Any]] = None, 145 json_body: Optional[Dict[str, Any]] = None, 146 ) -> Dict[str, Any]: 147 """HTTPリクエスト実行 148 149 Args: 150 method: HTTPメソッド (GET, POST, etc.) 151 endpoint: APIエンドポイント(baseURLからの相対パス) 152 params: クエリパラメータ 153 json_body: JSONリクエストボディ 154 155 Returns: 156 JSONレスポンス 157 158 Raises: 159 AuthUnauthorizedError: 認証失敗 (401) 160 AuthNotFoundError: リソース未発見 (404) 161 AuthRateLimitError: レート制限 (429) 162 AuthAPIError: その他のAPIエラー 163 """ 164 url = f"{self._config.base_url.rstrip('/')}{endpoint}" 165 client = self._get_http_client() 166 headers = self._get_headers() 167 168 try: 169 if method.upper() == "GET": 170 response = await client.get(url, headers=headers, params=params) 171 elif method.upper() == "POST": 172 response = await client.post(url, headers=headers, json=json_body) 173 else: 174 raise AuthAPIError(f"Unsupported HTTP method: {method}") 175 176 # ステータスコード別処理 177 if response.status_code == 200: 178 return response.json() 179 elif response.status_code == 401: 180 raise AuthUnauthorizedError() 181 elif response.status_code == 404: 182 raise AuthNotFoundError() 183 elif response.status_code == 429: 184 raise AuthRateLimitError() 185 else: 186 error_msg = f"Status {response.status_code}: {response.text}" 187 raise AuthAPIError(error_msg, response.status_code) 188 189 except httpx.RequestError as e: 190 raise AuthAPIError(f"Request failed: {e}") 191 192 # ============================================================ 193 # ユーザー一覧取得 194 # ============================================================ 195 196 async def get_users( 197 self, 198 page: int = 1, 199 limit: int = 20, 200 search: Optional[str] = None, 201 is_approved: Optional[str] = None, 202 job: Optional[str] = None, 203 organization: Optional[str] = None, 204 sort_by: Optional[str] = None, 205 order: Optional[str] = None, 206 include_last_login: bool = False, 207 ) -> GetUsersResult: 208 """ 209 ユーザー一覧を取得 210 211 Args: 212 page: ページ番号(デフォルト: 1) 213 limit: 1ページあたりの件数(デフォルト: 20) 214 search: 検索キーワード(email, username, displayNameで検索) 215 is_approved: 承認状態フィルター ('true', 'false', 'all') 216 job: 職種でフィルター 217 organization: 組織でフィルター 218 sort_by: ソート項目 ('created_timestamp', 'email', 'username') 219 order: ソート順 ('asc', 'desc') 220 include_last_login: 最終ログイン時刻を取得するか 221 222 Returns: 223 GetUsersResult: ユーザー一覧結果 224 225 Example: 226 ```python 227 result = await client.get_users( 228 page=1, 229 limit=20, 230 is_approved="true", 231 sort_by="created_timestamp", 232 order="desc", 233 ) 234 if result.success: 235 for user in result.users: 236 print(f"User: {user.email}") 237 ``` 238 """ 239 try: 240 params: Dict[str, Any] = { 241 "page": page, 242 "limit": limit, 243 } 244 245 if search: 246 params["search"] = search 247 if is_approved: 248 params["is_approved"] = is_approved 249 if job: 250 params["job"] = job 251 if organization: 252 params["organization"] = organization 253 if sort_by: 254 params["sort_by"] = sort_by 255 if order: 256 params["order"] = order 257 if include_last_login: 258 params["include_last_login"] = "true" 259 260 response = await self._request( 261 "GET", "/api/v1/admin/users", params=params 262 ) 263 264 # レスポンス解析 265 data = response.get("data", {}) 266 items_raw = data.get("items", []) 267 pagination_raw = data.get("pagination", {}) 268 269 users = [ApiUser.model_validate(item) for item in items_raw] 270 pagination = ( 271 UserPagination.model_validate(pagination_raw) 272 if pagination_raw 273 else None 274 ) 275 276 return GetUsersResult( 277 success=True, 278 users=users, 279 pagination=pagination, 280 ) 281 282 except (AuthUnauthorizedError, AuthNotFoundError, AuthRateLimitError, AuthAPIError) as e: 283 logger.error(f"Failed to get users: {e}") 284 return GetUsersResult( 285 success=False, 286 error=str(e), 287 error_code=getattr(e, "error_code", "API_ERROR"), 288 ) 289 except Exception as e: 290 logger.error(f"Unexpected error getting users: {e}") 291 return GetUsersResult( 292 success=False, 293 error=str(e), 294 error_code="UNKNOWN_ERROR", 295 ) 296 297 # ============================================================ 298 # ユーザー詳細取得 299 # ============================================================ 300 301 async def get_user(self, user_id: str) -> GetUserResult: 302 """ 303 ユーザー詳細を取得 304 305 Args: 306 user_id: ユーザーID 307 308 Returns: 309 GetUserResult: ユーザー詳細結果 310 311 Example: 312 ```python 313 result = await client.get_user("c45005f6-b715-47bb-a3d7-44ab208cecc4") 314 if result.success: 315 print(f"User: {result.user.email}") 316 if result.devices: 317 print(f"Devices: {len(result.devices)}") 318 ``` 319 """ 320 if not user_id: 321 return GetUserResult( 322 success=False, 323 error="user_id is required", 324 error_code="INVALID_PARAMETER", 325 ) 326 327 try: 328 response = await self._request( 329 "GET", f"/api/v1/admin/users/{user_id}" 330 ) 331 332 user_raw = response.get("user", {}) 333 devices_raw = response.get("devices", []) 334 login_history_raw = response.get("loginHistory", []) 335 336 user = ApiUser.model_validate(user_raw) 337 # 空配列はそのまま空配列として返す(Noneに変換しない) 338 devices = [DeviceInfo.model_validate(d) for d in devices_raw] 339 login_history = [LoginHistoryEntry.model_validate(h) for h in login_history_raw] 340 341 # ソフトデリート情報 342 is_deleted = user_raw.get("isDeleted", False) 343 deleted_at = user_raw.get("deletedAt") 344 deleted_by = user_raw.get("deletedBy") 345 deletion_reason = user_raw.get("deletionReason") 346 347 return GetUserResult( 348 success=True, 349 user=user, 350 devices=devices, 351 login_history=login_history, 352 is_deleted=is_deleted, 353 deleted_at=deleted_at, 354 deleted_by=deleted_by, 355 deletion_reason=deletion_reason, 356 ) 357 358 except AuthNotFoundError: 359 return GetUserResult( 360 success=False, 361 error=f"User not found: {user_id}", 362 error_code="NOT_FOUND", 363 ) 364 except (AuthUnauthorizedError, AuthRateLimitError, AuthAPIError) as e: 365 logger.error(f"Failed to get user {user_id}: {e}") 366 return GetUserResult( 367 success=False, 368 error=str(e), 369 error_code=getattr(e, "error_code", "API_ERROR"), 370 ) 371 except Exception as e: 372 logger.error(f"Unexpected error getting user {user_id}: {e}") 373 return GetUserResult( 374 success=False, 375 error=str(e), 376 error_code="UNKNOWN_ERROR", 377 ) 378 379 # ============================================================ 380 # MCP OAuth トークン取得 381 # ============================================================ 382 383 async def get_mcp_tokens( 384 self, 385 user_id: str, 386 providers: Optional[List[OAuthProviderName]] = None, 387 ) -> GetMCPTokensResult: 388 """ 389 MCP用OAuthトークンを取得 390 391 Args: 392 user_id: ユーザーID 393 providers: 取得するプロバイダー(省略時は全プロバイダー) 394 例: ["github", "slack", "google", "office365"] 395 396 Returns: 397 GetMCPTokensResult: MCPトークン取得結果 398 399 Example: 400 ```python 401 result = await client.get_mcp_tokens( 402 "user-id-123", 403 providers=["github", "slack"], 404 ) 405 if result.success: 406 for provider, token in result.tokens.items(): 407 print(f"{provider}: expires at {token.expires_at}") 408 if result.errors: 409 for provider, err in result.errors.items(): 410 print(f"{provider} error: {err.message}") 411 ``` 412 """ 413 if not user_id: 414 return GetMCPTokensResult( 415 success=False, 416 error="user_id is required", 417 error_code="INVALID_PARAMETER", 418 ) 419 420 try: 421 body: Dict[str, Any] = {"user_id": user_id} 422 if providers: 423 body["providers"] = providers 424 425 response = await self._request( 426 "POST", "/api/v1/oauth/mcp/token", json_body=body 427 ) 428 429 tokens_raw = response.get("tokens", {}) 430 errors_raw = response.get("errors", {}) 431 432 # トークン変換(expires_at: string -> datetime) 433 tokens: Dict[str, MCPTokenInfo] = {} 434 token_errors: Dict[str, MCPTokenError] = {} 435 436 for provider, token_data in tokens_raw.items(): 437 expires_at_str = token_data.get("expires_at") 438 439 # expires_atが欠損または不正な場合はエラーとして扱う 440 if not expires_at_str: 441 token_errors[provider] = MCPTokenError( 442 code="INVALID_TOKEN_DATA", 443 message="expires_at is missing", 444 ) 445 continue 446 447 try: 448 expires_at = datetime.fromisoformat( 449 expires_at_str.replace("Z", "+00:00") 450 ) 451 except (ValueError, TypeError) as e: 452 token_errors[provider] = MCPTokenError( 453 code="INVALID_EXPIRES_AT", 454 message=f"Invalid expires_at format: {expires_at_str}", 455 ) 456 continue 457 458 tokens[provider] = MCPTokenInfo( 459 access_token=token_data.get("access_token", ""), 460 token_type=token_data.get("token_type", "Bearer"), 461 expires_at=expires_at, 462 scopes=token_data.get("scopes", []), 463 ) 464 465 # APIからのエラーを追加 466 for provider, err in errors_raw.items(): 467 token_errors[provider] = MCPTokenError( 468 code=err.get("code", "UNKNOWN"), 469 message=err.get("message", "Unknown error"), 470 ) 471 472 # エラーがあればerrorsに設定 473 final_errors = token_errors if token_errors else None 474 475 return GetMCPTokensResult( 476 success=True, 477 tokens=tokens, 478 errors=final_errors, 479 ) 480 481 except (AuthUnauthorizedError, AuthRateLimitError, AuthAPIError) as e: 482 logger.error(f"Failed to get MCP tokens for {user_id}: {e}") 483 return GetMCPTokensResult( 484 success=False, 485 error=str(e), 486 error_code=getattr(e, "error_code", "API_ERROR"), 487 ) 488 except Exception as e: 489 logger.error(f"Unexpected error getting MCP tokens for {user_id}: {e}") 490 return GetMCPTokensResult( 491 success=False, 492 error=str(e), 493 error_code="UNKNOWN_ERROR", 494 ) 495 496 # ============================================================ 497 # リソースクローズ 498 # ============================================================ 499 500 async def close(self) -> None: 501 """クライアントリソースをクローズ""" 502 try: 503 if self._http_client and not self._http_client.is_closed: 504 await self._http_client.aclose() 505 except Exception as e: 506 logger.warning(f"Error closing HTTP client: {e}") 507 finally: 508 self._http_client = None 509 logger.debug("AgenticStarAuthClient closed")
AgenticStar Auth API クライアント
認証API(ユーザー管理、MCP OAuthトークン)へのアクセスを提供。 X-Internal-API-Key認証を使用します。
Attributes: config: AgenticStarAuthConfig設定
Example:
config = AgenticStarAuthConfig.create(
base_url="https://auth.example.com",
api_key="xxx",
)
async with AgenticStarAuthClient(config) as client:
users = await client.get_users()
83 def __init__(self, config: AgenticStarAuthConfig): 84 """ 85 AgenticStar Auth Clientを初期化 86 87 Args: 88 config: AgenticStarAuthConfig インスタンス 89 90 Raises: 91 AuthConfigError: 設定が不正な場合 92 """ 93 if not config.base_url: 94 raise AuthConfigError("base_url is required") 95 96 if not config.api_key: 97 raise AuthConfigError("api_key is required") 98 99 self._config = config 100 self._http_client: Optional[httpx.AsyncClient] = None 101 102 auth_type = getattr(config, 'auth_type', 'api_key') 103 logger.info( 104 f"AgenticStarAuthClient initialized - base_url={config.base_url[:30]}..., auth_type={auth_type}" 105 )
AgenticStar Auth Clientを初期化
Args: config: AgenticStarAuthConfig インスタンス
Raises: AuthConfigError: 設定が不正な場合
196 async def get_users( 197 self, 198 page: int = 1, 199 limit: int = 20, 200 search: Optional[str] = None, 201 is_approved: Optional[str] = None, 202 job: Optional[str] = None, 203 organization: Optional[str] = None, 204 sort_by: Optional[str] = None, 205 order: Optional[str] = None, 206 include_last_login: bool = False, 207 ) -> GetUsersResult: 208 """ 209 ユーザー一覧を取得 210 211 Args: 212 page: ページ番号(デフォルト: 1) 213 limit: 1ページあたりの件数(デフォルト: 20) 214 search: 検索キーワード(email, username, displayNameで検索) 215 is_approved: 承認状態フィルター ('true', 'false', 'all') 216 job: 職種でフィルター 217 organization: 組織でフィルター 218 sort_by: ソート項目 ('created_timestamp', 'email', 'username') 219 order: ソート順 ('asc', 'desc') 220 include_last_login: 最終ログイン時刻を取得するか 221 222 Returns: 223 GetUsersResult: ユーザー一覧結果 224 225 Example: 226 ```python 227 result = await client.get_users( 228 page=1, 229 limit=20, 230 is_approved="true", 231 sort_by="created_timestamp", 232 order="desc", 233 ) 234 if result.success: 235 for user in result.users: 236 print(f"User: {user.email}") 237 ``` 238 """ 239 try: 240 params: Dict[str, Any] = { 241 "page": page, 242 "limit": limit, 243 } 244 245 if search: 246 params["search"] = search 247 if is_approved: 248 params["is_approved"] = is_approved 249 if job: 250 params["job"] = job 251 if organization: 252 params["organization"] = organization 253 if sort_by: 254 params["sort_by"] = sort_by 255 if order: 256 params["order"] = order 257 if include_last_login: 258 params["include_last_login"] = "true" 259 260 response = await self._request( 261 "GET", "/api/v1/admin/users", params=params 262 ) 263 264 # レスポンス解析 265 data = response.get("data", {}) 266 items_raw = data.get("items", []) 267 pagination_raw = data.get("pagination", {}) 268 269 users = [ApiUser.model_validate(item) for item in items_raw] 270 pagination = ( 271 UserPagination.model_validate(pagination_raw) 272 if pagination_raw 273 else None 274 ) 275 276 return GetUsersResult( 277 success=True, 278 users=users, 279 pagination=pagination, 280 ) 281 282 except (AuthUnauthorizedError, AuthNotFoundError, AuthRateLimitError, AuthAPIError) as e: 283 logger.error(f"Failed to get users: {e}") 284 return GetUsersResult( 285 success=False, 286 error=str(e), 287 error_code=getattr(e, "error_code", "API_ERROR"), 288 ) 289 except Exception as e: 290 logger.error(f"Unexpected error getting users: {e}") 291 return GetUsersResult( 292 success=False, 293 error=str(e), 294 error_code="UNKNOWN_ERROR", 295 )
ユーザー一覧を取得
Args: page: ページ番号(デフォルト: 1) limit: 1ページあたりの件数(デフォルト: 20) search: 検索キーワード(email, username, displayNameで検索) is_approved: 承認状態フィルター ('true', 'false', 'all') job: 職種でフィルター organization: 組織でフィルター sort_by: ソート項目 ('created_timestamp', 'email', 'username') order: ソート順 ('asc', 'desc') include_last_login: 最終ログイン時刻を取得するか
Returns: GetUsersResult: ユーザー一覧結果
Example:
result = await client.get_users(
page=1,
limit=20,
is_approved="true",
sort_by="created_timestamp",
order="desc",
)
if result.success:
for user in result.users:
print(f"User: {user.email}")
301 async def get_user(self, user_id: str) -> GetUserResult: 302 """ 303 ユーザー詳細を取得 304 305 Args: 306 user_id: ユーザーID 307 308 Returns: 309 GetUserResult: ユーザー詳細結果 310 311 Example: 312 ```python 313 result = await client.get_user("c45005f6-b715-47bb-a3d7-44ab208cecc4") 314 if result.success: 315 print(f"User: {result.user.email}") 316 if result.devices: 317 print(f"Devices: {len(result.devices)}") 318 ``` 319 """ 320 if not user_id: 321 return GetUserResult( 322 success=False, 323 error="user_id is required", 324 error_code="INVALID_PARAMETER", 325 ) 326 327 try: 328 response = await self._request( 329 "GET", f"/api/v1/admin/users/{user_id}" 330 ) 331 332 user_raw = response.get("user", {}) 333 devices_raw = response.get("devices", []) 334 login_history_raw = response.get("loginHistory", []) 335 336 user = ApiUser.model_validate(user_raw) 337 # 空配列はそのまま空配列として返す(Noneに変換しない) 338 devices = [DeviceInfo.model_validate(d) for d in devices_raw] 339 login_history = [LoginHistoryEntry.model_validate(h) for h in login_history_raw] 340 341 # ソフトデリート情報 342 is_deleted = user_raw.get("isDeleted", False) 343 deleted_at = user_raw.get("deletedAt") 344 deleted_by = user_raw.get("deletedBy") 345 deletion_reason = user_raw.get("deletionReason") 346 347 return GetUserResult( 348 success=True, 349 user=user, 350 devices=devices, 351 login_history=login_history, 352 is_deleted=is_deleted, 353 deleted_at=deleted_at, 354 deleted_by=deleted_by, 355 deletion_reason=deletion_reason, 356 ) 357 358 except AuthNotFoundError: 359 return GetUserResult( 360 success=False, 361 error=f"User not found: {user_id}", 362 error_code="NOT_FOUND", 363 ) 364 except (AuthUnauthorizedError, AuthRateLimitError, AuthAPIError) as e: 365 logger.error(f"Failed to get user {user_id}: {e}") 366 return GetUserResult( 367 success=False, 368 error=str(e), 369 error_code=getattr(e, "error_code", "API_ERROR"), 370 ) 371 except Exception as e: 372 logger.error(f"Unexpected error getting user {user_id}: {e}") 373 return GetUserResult( 374 success=False, 375 error=str(e), 376 error_code="UNKNOWN_ERROR", 377 )
ユーザー詳細を取得
Args: user_id: ユーザーID
Returns: GetUserResult: ユーザー詳細結果
Example:
result = await client.get_user("c45005f6-b715-47bb-a3d7-44ab208cecc4")
if result.success:
print(f"User: {result.user.email}")
if result.devices:
print(f"Devices: {len(result.devices)}")
383 async def get_mcp_tokens( 384 self, 385 user_id: str, 386 providers: Optional[List[OAuthProviderName]] = None, 387 ) -> GetMCPTokensResult: 388 """ 389 MCP用OAuthトークンを取得 390 391 Args: 392 user_id: ユーザーID 393 providers: 取得するプロバイダー(省略時は全プロバイダー) 394 例: ["github", "slack", "google", "office365"] 395 396 Returns: 397 GetMCPTokensResult: MCPトークン取得結果 398 399 Example: 400 ```python 401 result = await client.get_mcp_tokens( 402 "user-id-123", 403 providers=["github", "slack"], 404 ) 405 if result.success: 406 for provider, token in result.tokens.items(): 407 print(f"{provider}: expires at {token.expires_at}") 408 if result.errors: 409 for provider, err in result.errors.items(): 410 print(f"{provider} error: {err.message}") 411 ``` 412 """ 413 if not user_id: 414 return GetMCPTokensResult( 415 success=False, 416 error="user_id is required", 417 error_code="INVALID_PARAMETER", 418 ) 419 420 try: 421 body: Dict[str, Any] = {"user_id": user_id} 422 if providers: 423 body["providers"] = providers 424 425 response = await self._request( 426 "POST", "/api/v1/oauth/mcp/token", json_body=body 427 ) 428 429 tokens_raw = response.get("tokens", {}) 430 errors_raw = response.get("errors", {}) 431 432 # トークン変換(expires_at: string -> datetime) 433 tokens: Dict[str, MCPTokenInfo] = {} 434 token_errors: Dict[str, MCPTokenError] = {} 435 436 for provider, token_data in tokens_raw.items(): 437 expires_at_str = token_data.get("expires_at") 438 439 # expires_atが欠損または不正な場合はエラーとして扱う 440 if not expires_at_str: 441 token_errors[provider] = MCPTokenError( 442 code="INVALID_TOKEN_DATA", 443 message="expires_at is missing", 444 ) 445 continue 446 447 try: 448 expires_at = datetime.fromisoformat( 449 expires_at_str.replace("Z", "+00:00") 450 ) 451 except (ValueError, TypeError) as e: 452 token_errors[provider] = MCPTokenError( 453 code="INVALID_EXPIRES_AT", 454 message=f"Invalid expires_at format: {expires_at_str}", 455 ) 456 continue 457 458 tokens[provider] = MCPTokenInfo( 459 access_token=token_data.get("access_token", ""), 460 token_type=token_data.get("token_type", "Bearer"), 461 expires_at=expires_at, 462 scopes=token_data.get("scopes", []), 463 ) 464 465 # APIからのエラーを追加 466 for provider, err in errors_raw.items(): 467 token_errors[provider] = MCPTokenError( 468 code=err.get("code", "UNKNOWN"), 469 message=err.get("message", "Unknown error"), 470 ) 471 472 # エラーがあればerrorsに設定 473 final_errors = token_errors if token_errors else None 474 475 return GetMCPTokensResult( 476 success=True, 477 tokens=tokens, 478 errors=final_errors, 479 ) 480 481 except (AuthUnauthorizedError, AuthRateLimitError, AuthAPIError) as e: 482 logger.error(f"Failed to get MCP tokens for {user_id}: {e}") 483 return GetMCPTokensResult( 484 success=False, 485 error=str(e), 486 error_code=getattr(e, "error_code", "API_ERROR"), 487 ) 488 except Exception as e: 489 logger.error(f"Unexpected error getting MCP tokens for {user_id}: {e}") 490 return GetMCPTokensResult( 491 success=False, 492 error=str(e), 493 error_code="UNKNOWN_ERROR", 494 )
MCP用OAuthトークンを取得
Args: user_id: ユーザーID providers: 取得するプロバイダー(省略時は全プロバイダー) 例: ["github", "slack", "google", "office365"]
Returns: GetMCPTokensResult: MCPトークン取得結果
Example:
result = await client.get_mcp_tokens(
"user-id-123",
providers=["github", "slack"],
)
if result.success:
for provider, token in result.tokens.items():
print(f"{provider}: expires at {token.expires_at}")
if result.errors:
for provider, err in result.errors.items():
print(f"{provider} error: {err.message}")
500 async def close(self) -> None: 501 """クライアントリソースをクローズ""" 502 try: 503 if self._http_client and not self._http_client.is_closed: 504 await self._http_client.aclose() 505 except Exception as e: 506 logger.warning(f"Error closing HTTP client: {e}") 507 finally: 508 self._http_client = None 509 logger.debug("AgenticStarAuthClient closed")
クライアントリソースをクローズ
21@dataclass 22class AgenticStarAuthConfig: 23 """AgenticStar Auth API設定 24 25 Args: 26 base_url: APIベースURL (例: "https://auth.example.com") 27 api_key: X-Internal-API-Key認証用APIキー、またはBearer token 28 auth_type: 認証タイプ ("api_key" or "bearer") 29 timeout: HTTPリクエストタイムアウト秒数 30 31 Note: 32 api_keyはrepr=Falseでログ出力から除外されます。 33 CLIモードでは create() で直接設定を注入してください。 34 35 Example: 36 ```python 37 # コードから直接指定 38 config = AgenticStarAuthConfig.create( 39 base_url="https://auth.example.com", 40 api_key="your-api-key", 41 ) 42 43 # config.tomlから読み込み(Podモード) 44 config = AgenticStarAuthConfig.from_config() 45 ``` 46 """ 47 48 # 必須設定 49 base_url: str 50 api_key: str = field(repr=False) # API key or Bearer token 51 52 # オプション設定 53 auth_type: str = "api_key" # "api_key" (Pod) or "bearer" (CLI) 54 timeout: float = 30.0 55 56 def __str__(self) -> str: 57 """文字列表現(APIキーをマスク)""" 58 return f"AgenticStarAuthConfig(base_url={self.base_url!r}, api_key=***)" 59 60 @classmethod 61 def create( 62 cls, 63 base_url: str, 64 api_key: str, 65 auth_type: str = "api_key", 66 timeout: float = 30.0, 67 ) -> "AgenticStarAuthConfig": 68 """コードから直接設定を指定して作成 69 70 Args: 71 base_url: APIベースURL 72 api_key: APIキー(またはBearer token) 73 auth_type: 認証タイプ ("api_key" or "bearer") 74 timeout: タイムアウト秒数 75 76 Returns: 77 AgenticStarAuthConfig インスタンス 78 79 Example: 80 ```python 81 # Pod mode (X-Internal-API-Key) 82 config = AgenticStarAuthConfig.create( 83 base_url="https://auth.example.com", 84 api_key="your-api-key", 85 ) 86 87 # CLI mode (Bearer token) 88 config = AgenticStarAuthConfig.create( 89 base_url="https://auth.example.com", 90 api_key="oauth-token", 91 auth_type="bearer", 92 ) 93 ``` 94 """ 95 return cls( 96 base_url=base_url, 97 api_key=api_key, 98 auth_type=auth_type, 99 timeout=timeout, 100 ) 101 102 @classmethod 103 def from_config( 104 cls, 105 config_path: Optional[Path] = None, 106 ) -> "AgenticStarAuthConfig": 107 """config.tomlから設定を読み込む 108 109 Args: 110 config_path: config.tomlのパス(省略時は自動検出) 111 112 Returns: 113 AgenticStarAuthConfig インスタンス 114 115 Raises: 116 AuthConfigError: 設定ファイルの読み込みに失敗した場合 117 ValueError: 必須設定が見つからない場合 118 """ 119 toml_config = cls._load_toml_config(config_path) 120 121 base_url = toml_config.get("base_url") 122 api_key = toml_config.get("api_key") 123 timeout = toml_config.get("timeout", 30.0) 124 125 # バリデーション 126 if not base_url: 127 raise ValueError( 128 "base_url is required. Configure [auth.agenticstar] base_url in config.toml" 129 ) 130 131 if not api_key: 132 raise ValueError( 133 "api_key is required. Configure [auth.agenticstar] api_key in config.toml" 134 ) 135 136 return cls( 137 base_url=base_url, 138 api_key=api_key, 139 auth_type="api_key", 140 timeout=float(timeout), 141 ) 142 143 @staticmethod 144 def _load_toml_config(config_path: Optional[Path] = None) -> Dict[str, Any]: 145 """config.tomlから[auth.agenticstar]セクションを読み込む 146 147 Raises: 148 AuthConfigError: ファイル読み込みやパースに失敗した場合 149 """ 150 import tomllib 151 152 # パスが指定されていない場合は自動検出 153 candidates = [] 154 if config_path is None: 155 # 優先順位: カレントディレクトリ > プロジェクトルート 156 candidates = [ 157 Path.cwd() / "config.toml", 158 Path(__file__).parent.parent.parent.parent / "config.toml", # sdk/ の親 = プロジェクトルート 159 ] 160 for candidate in candidates: 161 if candidate.exists(): 162 config_path = candidate 163 break 164 165 if config_path is None or not config_path.exists(): 166 raise AuthConfigError( 167 f"config.toml not found. Searched: {[str(c) for c in candidates]}" 168 ) 169 170 try: 171 with open(config_path, "rb") as f: 172 config = tomllib.load(f) 173 except tomllib.TOMLDecodeError as e: 174 raise AuthConfigError(f"Invalid TOML syntax in {config_path}: {e}") 175 except PermissionError as e: 176 raise AuthConfigError(f"Permission denied reading {config_path}: {e}") 177 except Exception as e: 178 raise AuthConfigError(f"Failed to read {config_path}: {e}") 179 180 auth_agenticstar = config.get("auth", {}).get("agenticstar") 181 if auth_agenticstar is None: 182 raise AuthConfigError( 183 f"[auth.agenticstar] section not found in {config_path}" 184 ) 185 186 return auth_agenticstar
AgenticStar Auth API設定
Args: base_url: APIベースURL (例: "https://auth.example.com") api_key: X-Internal-API-Key認証用APIキー、またはBearer token auth_type: 認証タイプ ("api_key" or "bearer") timeout: HTTPリクエストタイムアウト秒数
Note: api_keyはrepr=Falseでログ出力から除外されます。 CLIモードでは create() で直接設定を注入してください。
Example:
# コードから直接指定
config = AgenticStarAuthConfig.create(
base_url="https://auth.example.com",
api_key="your-api-key",
)
# config.tomlから読み込み(Podモード)
config = AgenticStarAuthConfig.from_config()
60 @classmethod 61 def create( 62 cls, 63 base_url: str, 64 api_key: str, 65 auth_type: str = "api_key", 66 timeout: float = 30.0, 67 ) -> "AgenticStarAuthConfig": 68 """コードから直接設定を指定して作成 69 70 Args: 71 base_url: APIベースURL 72 api_key: APIキー(またはBearer token) 73 auth_type: 認証タイプ ("api_key" or "bearer") 74 timeout: タイムアウト秒数 75 76 Returns: 77 AgenticStarAuthConfig インスタンス 78 79 Example: 80 ```python 81 # Pod mode (X-Internal-API-Key) 82 config = AgenticStarAuthConfig.create( 83 base_url="https://auth.example.com", 84 api_key="your-api-key", 85 ) 86 87 # CLI mode (Bearer token) 88 config = AgenticStarAuthConfig.create( 89 base_url="https://auth.example.com", 90 api_key="oauth-token", 91 auth_type="bearer", 92 ) 93 ``` 94 """ 95 return cls( 96 base_url=base_url, 97 api_key=api_key, 98 auth_type=auth_type, 99 timeout=timeout, 100 )
コードから直接設定を指定して作成
Args: base_url: APIベースURL api_key: APIキー(またはBearer token) auth_type: 認証タイプ ("api_key" or "bearer") timeout: タイムアウト秒数
Returns: AgenticStarAuthConfig インスタンス
Example:
# Pod mode (X-Internal-API-Key)
config = AgenticStarAuthConfig.create(
base_url="https://auth.example.com",
api_key="your-api-key",
)
# CLI mode (Bearer token)
config = AgenticStarAuthConfig.create(
base_url="https://auth.example.com",
api_key="oauth-token",
auth_type="bearer",
)
102 @classmethod 103 def from_config( 104 cls, 105 config_path: Optional[Path] = None, 106 ) -> "AgenticStarAuthConfig": 107 """config.tomlから設定を読み込む 108 109 Args: 110 config_path: config.tomlのパス(省略時は自動検出) 111 112 Returns: 113 AgenticStarAuthConfig インスタンス 114 115 Raises: 116 AuthConfigError: 設定ファイルの読み込みに失敗した場合 117 ValueError: 必須設定が見つからない場合 118 """ 119 toml_config = cls._load_toml_config(config_path) 120 121 base_url = toml_config.get("base_url") 122 api_key = toml_config.get("api_key") 123 timeout = toml_config.get("timeout", 30.0) 124 125 # バリデーション 126 if not base_url: 127 raise ValueError( 128 "base_url is required. Configure [auth.agenticstar] base_url in config.toml" 129 ) 130 131 if not api_key: 132 raise ValueError( 133 "api_key is required. Configure [auth.agenticstar] api_key in config.toml" 134 ) 135 136 return cls( 137 base_url=base_url, 138 api_key=api_key, 139 auth_type="api_key", 140 timeout=float(timeout), 141 )
config.tomlから設定を読み込む
Args: config_path: config.tomlのパス(省略時は自動検出)
Returns: AgenticStarAuthConfig インスタンス
Raises: AuthConfigError: 設定ファイルの読み込みに失敗した場合 ValueError: 必須設定が見つからない場合
Auth操作の基底エラー
16class AuthConfigError(AuthError): 17 """Auth設定エラー 18 19 Raises: 20 設定が不正または不足している場合 21 """ 22 23 pass
Auth設定エラー
Raises: 設定が不正または不足している場合
26class AuthAPIError(AuthError): 27 """Auth API呼び出しエラー 28 29 Attributes: 30 status_code: HTTPステータスコード 31 error_code: エラーコード文字列 32 """ 33 34 def __init__( 35 self, 36 message: str, 37 status_code: Optional[int] = None, 38 error_code: str = "API_ERROR", 39 ): 40 super().__init__(message) 41 self.status_code = status_code 42 self.error_code = error_code
Auth API呼び出しエラー
Attributes: status_code: HTTPステータスコード error_code: エラーコード文字列
56class AuthNotFoundError(AuthAPIError): 57 """リソース未発見エラー(404 Not Found) 58 59 Raises: 60 指定されたユーザーまたはリソースが存在しない場合 61 """ 62 63 def __init__(self, message: str = "Resource not found"): 64 super().__init__(message, status_code=404, error_code="NOT_FOUND")
リソース未発見エラー(404 Not Found)
Raises: 指定されたユーザーまたはリソースが存在しない場合
67class AuthRateLimitError(AuthAPIError): 68 """レート制限エラー(429 Too Many Requests) 69 70 Raises: 71 APIレート制限に達した場合 72 """ 73 74 def __init__(self, message: str = "Rate limit exceeded"): 75 super().__init__(message, status_code=429, error_code="RATE_LIMITED")
レート制限エラー(429 Too Many Requests)
Raises: APIレート制限に達した場合
39class ApiUser(BaseModel): 40 """ユーザー情報 41 42 chatboardloginのAPIレスポンス形式に対応: 43 - attributes.organization: { id: string, label: string } 44 - attributes.job: { id: string, label: string } 45 """ 46 47 id: str 48 username: Optional[str] = None 49 email: Optional[str] = None 50 first_name: Optional[str] = Field(None, alias="firstName") 51 last_name: Optional[str] = Field(None, alias="lastName") 52 display_name: Optional[str] = Field(None, alias="displayName") 53 email_verified: bool = Field(False, alias="emailVerified") 54 enabled: bool = True 55 created_timestamp: Optional[int] = Field(None, alias="created_timestamp") 56 attributes: Optional[Dict[str, Any]] = None 57 organization: Optional[str] = None 58 organization_label: Optional[str] = Field(None, alias="organizationLabel") 59 job: Optional[str] = None 60 job_label: Optional[str] = Field(None, alias="jobLabel") 61 bio: Optional[str] = None 62 last_login_at: Optional[str] = Field(None, alias="lastLoginAt") 63 64 model_config = {"populate_by_name": True} 65 66 @model_validator(mode="before") 67 @classmethod 68 def extract_from_attributes(cls, data: Any) -> Any: 69 """attributesからorganization/jobのid/labelを抽出 70 71 chatboardloginのレスポンス形式: 72 { 73 "attributes": { 74 "organization": { "id": "se_bd1", "label": "日本語ラベル" }, 75 "job": { "id": "frontend", "label": "日本語ラベル" }, 76 "bio": "自己紹介" 77 } 78 } 79 """ 80 if not isinstance(data, dict): 81 return data 82 83 attributes = data.get("attributes") 84 if not attributes or not isinstance(attributes, dict): 85 return data 86 87 # organization: { id, label } からid/labelを抽出 88 org_data = attributes.get("organization") 89 if isinstance(org_data, dict): 90 if "id" in org_data and data.get("organization") is None: 91 data["organization"] = org_data["id"] 92 if "label" in org_data and data.get("organization_label") is None: 93 data["organization_label"] = org_data["label"] 94 95 # job: { id, label } からid/labelを抽出 96 job_data = attributes.get("job") 97 if isinstance(job_data, dict): 98 if "id" in job_data and data.get("job") is None: 99 data["job"] = job_data["id"] 100 if "label" in job_data and data.get("job_label") is None: 101 data["job_label"] = job_data["label"] 102 103 # bio: 文字列として直接取得 104 bio_data = attributes.get("bio") 105 if isinstance(bio_data, str) and data.get("bio") is None: 106 data["bio"] = bio_data 107 108 return data 109 110 # ID取得メソッド 111 def get_org_id(self) -> Optional[str]: 112 """組織IDを取得""" 113 return self.organization 114 115 def get_job_id(self) -> Optional[str]: 116 """職種IDを取得""" 117 return self.job 118 119 # ラベル(表示値)取得メソッド 120 def get_display_name(self) -> Optional[str]: 121 """表示名を取得""" 122 return self.display_name 123 124 def get_organization_label(self) -> Optional[str]: 125 """組織ラベルを取得""" 126 return self.organization_label 127 128 def get_job_label(self) -> Optional[str]: 129 """職種ラベルを取得""" 130 return self.job_label 131 132 def get_bio(self) -> Optional[str]: 133 """自己紹介を取得""" 134 return self.bio
ユーザー情報
chatboardloginのAPIレスポンス形式に対応:
- attributes.organization: { id: string, label: string }
- attributes.job: { id: string, label: string }
66 @model_validator(mode="before") 67 @classmethod 68 def extract_from_attributes(cls, data: Any) -> Any: 69 """attributesからorganization/jobのid/labelを抽出 70 71 chatboardloginのレスポンス形式: 72 { 73 "attributes": { 74 "organization": { "id": "se_bd1", "label": "日本語ラベル" }, 75 "job": { "id": "frontend", "label": "日本語ラベル" }, 76 "bio": "自己紹介" 77 } 78 } 79 """ 80 if not isinstance(data, dict): 81 return data 82 83 attributes = data.get("attributes") 84 if not attributes or not isinstance(attributes, dict): 85 return data 86 87 # organization: { id, label } からid/labelを抽出 88 org_data = attributes.get("organization") 89 if isinstance(org_data, dict): 90 if "id" in org_data and data.get("organization") is None: 91 data["organization"] = org_data["id"] 92 if "label" in org_data and data.get("organization_label") is None: 93 data["organization_label"] = org_data["label"] 94 95 # job: { id, label } からid/labelを抽出 96 job_data = attributes.get("job") 97 if isinstance(job_data, dict): 98 if "id" in job_data and data.get("job") is None: 99 data["job"] = job_data["id"] 100 if "label" in job_data and data.get("job_label") is None: 101 data["job_label"] = job_data["label"] 102 103 # bio: 文字列として直接取得 104 bio_data = attributes.get("bio") 105 if isinstance(bio_data, str) and data.get("bio") is None: 106 data["bio"] = bio_data 107 108 return data
attributesからorganization/jobのid/labelを抽出
chatboardloginのレスポンス形式: { "attributes": { "organization": { "id": "se_bd1", "label": "日本語ラベル" }, "job": { "id": "frontend", "label": "日本語ラベル" }, "bio": "自己紹介" } }
124 def get_organization_label(self) -> Optional[str]: 125 """組織ラベルを取得""" 126 return self.organization_label
組織ラベルを取得
142class DeviceInfo(BaseModel): 143 """デバイス情報""" 144 145 id: Optional[str] = None 146 device_code: str 147 device_name: Optional[str] = None 148 device_type: Optional[str] = None 149 device_info: Optional[Dict[str, Any]] = None 150 is_active: Optional[bool] = None 151 created_at: Optional[str] = None 152 last_used_at: Optional[str] = None
デバイス情報
155class LoginHistoryEntry(BaseModel): 156 """ログイン履歴エントリ""" 157 158 timestamp: int 159 ip_address: str = Field(alias="ipAddress") 160 user_agent: str = Field(alias="userAgent") 161 platform: str 162 location: Optional[Dict[str, Optional[str]]] = None 163 success: bool 164 165 model_config = {"populate_by_name": True}
ログイン履歴エントリ
26class UserPagination(BaseModel): 27 """ページネーション情報""" 28 29 page: int 30 limit: int 31 total: int 32 total_pages: int = Field(alias="totalPages") 33 has_next: bool = Field(alias="hasNext") 34 has_previous: bool = Field(alias="hasPrevious") 35 36 model_config = {"populate_by_name": True}
ページネーション情報
173class MCPTokenInfo(BaseModel): 174 """MCPトークン情報""" 175 176 access_token: str 177 token_type: str 178 expires_at: datetime 179 scopes: List[str]
MCPトークン情報
182class MCPTokenError(BaseModel): 183 """MCPトークン取得エラー 184 185 Attributes: 186 code: エラーコード (CONNECTION_NOT_FOUND, TOKEN_EXPIRED, REFRESH_FAILED等) 187 message: エラーメッセージ 188 """ 189 190 code: str 191 message: str
MCPトークン取得エラー
Attributes: code: エラーコード (CONNECTION_NOT_FOUND, TOKEN_EXPIRED, REFRESH_FAILED等) message: エラーメッセージ
199class GetUsersResult(BaseModel): 200 """ユーザー一覧取得結果 201 202 Attributes: 203 success: 成功したかどうか 204 users: ユーザーリスト 205 pagination: ページネーション情報 206 error: エラーメッセージ(失敗時) 207 error_code: エラーコード(失敗時) 208 """ 209 210 success: bool 211 users: List[ApiUser] = Field(default_factory=list) 212 pagination: Optional[UserPagination] = None 213 error: Optional[str] = None 214 error_code: Optional[str] = None
ユーザー一覧取得結果
Attributes: success: 成功したかどうか users: ユーザーリスト pagination: ページネーション情報 error: エラーメッセージ(失敗時) error_code: エラーコード(失敗時)
217class GetUserResult(BaseModel): 218 """ユーザー詳細取得結果 219 220 Attributes: 221 success: 成功したかどうか 222 user: ユーザー情報 223 devices: デバイス情報リスト(空配列も保持) 224 login_history: ログイン履歴(空配列も保持) 225 error: エラーメッセージ(失敗時) 226 error_code: エラーコード(失敗時) 227 """ 228 229 success: bool 230 user: Optional[ApiUser] = None 231 devices: List[DeviceInfo] = Field(default_factory=list) 232 login_history: List[LoginHistoryEntry] = Field(default_factory=list) 233 is_deleted: bool = False 234 deleted_at: Optional[str] = None 235 deleted_by: Optional[str] = None 236 deletion_reason: Optional[str] = None 237 error: Optional[str] = None 238 error_code: Optional[str] = None
ユーザー詳細取得結果
Attributes: success: 成功したかどうか user: ユーザー情報 devices: デバイス情報リスト(空配列も保持) login_history: ログイン履歴(空配列も保持) error: エラーメッセージ(失敗時) error_code: エラーコード(失敗時)
241class GetMCPTokensResult(BaseModel): 242 """MCPトークン取得結果 243 244 Attributes: 245 success: 成功したかどうか 246 tokens: プロバイダー別トークン情報 247 errors: プロバイダー別エラー情報 248 error: 全体エラーメッセージ(失敗時) 249 error_code: 全体エラーコード(失敗時) 250 251 Example: 252 ```python 253 result = await client.get_mcp_tokens("user-id", ["github", "slack"]) 254 if result.success: 255 for provider, token in result.tokens.items(): 256 print(f"{provider}: expires at {token.expires_at}") 257 if result.errors: 258 for provider, err in result.errors.items(): 259 print(f"{provider} error: {err.message}") 260 ``` 261 """ 262 263 success: bool 264 tokens: Dict[str, MCPTokenInfo] = Field(default_factory=dict) 265 errors: Optional[Dict[str, MCPTokenError]] = None 266 error: Optional[str] = None 267 error_code: Optional[str] = None
MCPトークン取得結果
Attributes: success: 成功したかどうか tokens: プロバイダー別トークン情報 errors: プロバイダー別エラー情報 error: 全体エラーメッセージ(失敗時) error_code: 全体エラーコード(失敗時)
Example:
result = await client.get_mcp_tokens("user-id", ["github", "slack"])
if result.success:
for provider, token in result.tokens.items():
print(f"{provider}: expires at {token.expires_at}")
if result.errors:
for provider, err in result.errors.items():
print(f"{provider} error: {err.message}")