Coverage for src/alprina_cli/api/services/neon_service.py: 0%

339 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-12 18:07 +0100

1""" 

2Neon Database Service 

3Replaces SupabaseService with direct PostgreSQL access via asyncpg. 

4""" 

5 

6import os 

7import secrets 

8import hashlib 

9import bcrypt 

10import asyncpg 

11from typing import Optional, Dict, Any, List 

12from datetime import datetime, timedelta 

13from loguru import logger 

14 

15 

16class NeonService: 

17 """Service for Neon PostgreSQL database operations.""" 

18 

def __init__(self):
    """Read DATABASE_URL and prepare lazy connection-pool state.

    When DATABASE_URL is unset or empty the service is marked disabled and
    the data methods short-circuit. No pool is created here; get_pool()
    builds it on first use.
    """
    self.database_url = os.getenv("DATABASE_URL")
    self.pool = None  # created lazily by get_pool()
    self.enabled = bool(self.database_url)

    if not self.enabled:
        logger.warning("DATABASE_URL not found - database features disabled")
        return
    logger.info("Neon service initialized")

31 

async def get_pool(self) -> "asyncpg.Pool":
    """Return the shared asyncpg pool, creating it on first use.

    Pool creation awaits, so without coordination several coroutines that
    arrive before the first pool exists would each create their own pool
    and all but one would leak their connections. Creation is therefore
    serialised with an asyncio.Lock and the pool is re-checked after the
    lock is acquired.

    Raises:
        Whatever asyncpg.create_pool raises; the error is logged first.
    """
    if self.pool is not None:
        return self.pool

    import asyncio

    # Created lazily so this works even if __init__ never set it. No await
    # happens between the lookup and the assignment, so this is race-free.
    lock = getattr(self, "_pool_lock", None)
    if lock is None:
        lock = self._pool_lock = asyncio.Lock()

    async with lock:
        # Another coroutine may have created the pool while we waited.
        if self.pool is None:
            try:
                self.pool = await asyncpg.create_pool(
                    self.database_url,
                    min_size=2,
                    max_size=10,
                    command_timeout=60
                )
                logger.info("✅ Neon connection pool created successfully")
            except Exception as e:
                logger.error(f"❌ Failed to create Neon connection pool: {e}")
                raise
    return self.pool

47 

def is_enabled(self) -> bool:
    """Report whether database access is configured (DATABASE_URL was set)."""
    configured = self.enabled
    return configured

51 

52 # ========================================== 

53 # User Management 

54 # ========================================== 

55 

async def create_user(
    self,
    email: str,
    password: str,
    full_name: Optional[str] = None
) -> Dict[str, Any]:
    """Create a password-based user (tier 'none') and issue a default API key.

    The plaintext API key is only returned here; create_api_key stores just
    its SHA-256 hash.

    The user row and the API key are written over separate pool connections,
    so there is no shared transaction. If issuing the key fails, the new user
    row is deleted again (so the account is not left half-created without a
    key, blocking a retry with the same email) and the error is re-raised.

    Args:
        email: Login email for the new user.
        password: Plaintext password (str or bytes); stored as a bcrypt hash.
        full_name: Optional display name.

    Returns:
        Dict with user_id, email, full_name, tier, api_key and created_at
        (ISO-8601 string).

    Raises:
        Exception: "Database not configured" when disabled, or any error
            raised by the database driver.
    """
    if not self.is_enabled():
        raise Exception("Database not configured")

    pool = await self.get_pool()

    # bcrypt works on bytes; accept either str or already-encoded bytes.
    if isinstance(password, str):
        password = password.encode('utf-8')
    password_hash = bcrypt.hashpw(password, bcrypt.gensalt()).decode('utf-8')

    async with pool.acquire() as conn:
        user = await conn.fetchrow(
            """
            INSERT INTO users (email, password_hash, full_name, tier)
            VALUES ($1, $2, $3, 'none')
            RETURNING id, email, full_name, tier, created_at
            """,
            email, password_hash, full_name
        )

    api_key = self.generate_api_key()
    try:
        await self.create_api_key(str(user['id']), api_key, "Default API Key")
    except Exception:
        # Compensating action: undo the user insert, then surface the error.
        async with pool.acquire() as conn:
            await conn.execute("DELETE FROM users WHERE id = $1", user['id'])
        raise

    logger.info(f"Created user: {email}")

    return {
        "user_id": str(user['id']),
        "email": user['email'],
        "full_name": user['full_name'],
        "tier": user['tier'],
        "api_key": api_key,
        "created_at": user['created_at'].isoformat()
    }

97 

async def authenticate_user(
    self,
    email: str,
    password: str
) -> Optional[Dict[str, Any]]:
    """Check an email/password pair against the stored bcrypt hash.

    Returns:
        The full users row as a dict on success; None when the database is
        disabled, the user does not exist, has no password (e.g. created from
        a subscription), the password is wrong, or the stored hash is not a
        valid bcrypt hash.
    """
    if not self.is_enabled():
        return None

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        user = await conn.fetchrow(
            "SELECT * FROM users WHERE email = $1",
            email
        )

    if not user or not user['password_hash']:
        return None

    # bcrypt compares bytes.
    if isinstance(password, str):
        password = password.encode('utf-8')

    try:
        matches = bcrypt.checkpw(password, user['password_hash'].encode('utf-8'))
    except ValueError:
        # bcrypt raises ValueError (e.g. "Invalid salt") for a stored value that
        # is not a bcrypt hash; treat that as a failed login instead of a crash.
        logger.warning(f"Stored password hash for {email} is not a valid bcrypt hash")
        return None

    return dict(user) if matches else None

126 

async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a users row by primary key; None if missing or DB disabled."""
    if not self.is_enabled():
        return None

    pool = await self.get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow("SELECT * FROM users WHERE id = $1", user_id)

    if not row:
        return None
    return dict(row)

140 

async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
    """Return the users row whose email equals ``email``, or None.

    The lookup is plain SQL equality; this query applies no case folding.
    """
    if not self.is_enabled():
        return None
    pool = await self.get_pool()
    query = "SELECT * FROM users WHERE email = $1"
    async with pool.acquire() as conn:
        found = await conn.fetchrow(query, email)
    return None if not found else dict(found)

154 

async def get_user_by_subscription(
    self,
    subscription_id: str
) -> Optional[Dict[str, Any]]:
    """Look up the user that owns a Polar subscription.

    Returns the full users row as a dict, or None when the database is
    disabled or no row has ``polar_subscription_id == subscription_id``.
    """
    if self.is_enabled():
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            record = await conn.fetchrow(
                "SELECT * FROM users WHERE polar_subscription_id = $1",
                subscription_id
            )
        if record:
            return dict(record)
    return None

171 

async def create_user_from_subscription(
    self,
    email: str,
    polar_customer_id: str,
    polar_subscription_id: str,
    tier: str,
    billing_period: str = "monthly",
    has_metering: bool = True,
    scans_included: int = 0,
    period_start: Optional[datetime] = None,
    period_end: Optional[datetime] = None,
    seats_included: int = 1
) -> Dict[str, Any]:
    """Insert a password-less user provisioned from a Polar subscription.

    The row starts with subscription_status 'active', zero scans used this
    period, one seat used and no extra seats.

    Returns:
        Dict with the new row's id, email, tier and created_at.

    Raises:
        Exception: "Database not configured" when disabled, or any driver error.
    """
    if not self.is_enabled():
        raise Exception("Database not configured")

    pool = await self.get_pool()

    params = (
        email, tier, polar_customer_id, polar_subscription_id,
        billing_period, has_metering, scans_included,
        period_start, period_end, seats_included,
    )

    async with pool.acquire() as conn:
        created = await conn.fetchrow(
            """
            INSERT INTO users (
                email,
                tier,
                polar_customer_id,
                polar_subscription_id,
                subscription_status,
                billing_period,
                has_metering,
                scans_included,
                scans_used_this_period,
                period_start,
                period_end,
                seats_included,
                seats_used,
                extra_seats
            )
            VALUES ($1, $2, $3, $4, 'active', $5, $6, $7, 0, $8, $9, $10, 1, 0)
            RETURNING id, email, tier, created_at
            """,
            *params
        )

    logger.info(f"Created user from subscription: {email}")
    return dict(created)

221 

async def update_user(
    self,
    user_id: str,
    updates: Dict[str, Any]
) -> bool:
    """Update arbitrary columns on a single users row.

    Column names come from the keys of ``updates`` and are interpolated into
    the SQL text (only the values are bind parameters). Every key is
    therefore validated as a plain ASCII identifier so a caller-supplied key
    cannot inject SQL.

    Returns:
        True if exactly one row was updated; False when the database is
        disabled, ``updates`` is empty (nothing to do; an empty SET clause
        would be a syntax error), or no row matched ``user_id``.

    Raises:
        ValueError: if any key is not a plain identifier.
    """
    if not self.is_enabled():
        return False
    if not updates:
        return False

    bad_keys = [
        key for key in updates
        if not (isinstance(key, str) and key.isascii() and key.isidentifier())
    ]
    if bad_keys:
        raise ValueError(f"Invalid column name(s) for update: {bad_keys!r}")

    pool = await self.get_pool()

    columns = list(updates)
    set_clause = ", ".join(
        f"{column} = ${position}" for position, column in enumerate(columns, start=1)
    )
    values = [updates[column] for column in columns]
    values.append(user_id)

    query = f"""
        UPDATE users
        SET {set_clause}
        WHERE id = ${len(values)}
    """

    async with pool.acquire() as conn:
        result = await conn.execute(query, *values)
    return result == "UPDATE 1"

254 

async def initialize_usage_tracking(
    self,
    user_id: str,
    tier: str
):
    """Copy the tier's request and scan limits onto the user's row.

    Limits come from polar_service.get_tier_limits(tier); a limit missing
    from that mapping is stored as 0. Persisted through update_user, which
    itself no-ops when the database is disabled.
    """
    from ..services.polar_service import polar_service

    tier_limits = polar_service.get_tier_limits(tier)
    usage_columns = {
        "requests_per_hour": tier_limits.get("api_requests_per_hour", 0),
        "scans_per_month": tier_limits.get("scans_per_month", 0),
    }
    await self.update_user(user_id, usage_columns)

272 

async def increment_user_scans(self, user_id: str):
    """Record one more scan against the user's current billing period.

    Increments ``scans_used_this_period``, the usage counter that
    create_user_from_subscription initialises to 0. ``scans_per_month`` is
    the user's *limit* (written from tier limits by
    initialize_usage_tracking and reported as ``scans_limit`` by
    get_user_stats), so bumping it would raise the allowance on every scan.
    COALESCE covers rows where the counter was never populated.
    """
    if not self.is_enabled():
        return

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        await conn.execute(
            """
            UPDATE users
            SET scans_used_this_period = COALESCE(scans_used_this_period, 0) + 1
            WHERE id = $1
            """,
            user_id
        )

289 

290 # ========================================== 

291 # API Key Management 

292 # ========================================== 

293 

def generate_api_key(self) -> str:
    """Return a fresh secret API key: ``alprina_sk_`` plus a URL-safe token."""
    prefix = "alprina_sk_"
    random_part = secrets.token_urlsafe(32)
    return prefix + random_part

297 

async def create_api_key(
    self,
    user_id: str,
    api_key: str,
    name: str,
    expires_at: Optional[datetime] = None
) -> Dict[str, Any]:
    """Store an API key for a user and return its public metadata.

    Only the SHA-256 hex digest of ``api_key`` and its first 16 characters
    (for display) are persisted; the plaintext key is never written.

    Returns:
        Dict with id (str), name, key_prefix and created_at (ISO-8601).

    Raises:
        Exception: "Database not configured" when disabled, or any driver error.
    """
    if not self.is_enabled():
        raise Exception("Database not configured")

    pool = await self.get_pool()

    digest = hashlib.sha256(api_key.encode()).hexdigest()
    display_prefix = api_key[:16]

    async with pool.acquire() as conn:
        stored = await conn.fetchrow(
            """
            INSERT INTO api_keys (user_id, key_hash, key_prefix, name, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, name, key_prefix, created_at
            """,
            user_id, digest, display_prefix, name, expires_at
        )

    return dict(
        id=str(stored['id']),
        name=stored['name'],
        key_prefix=stored['key_prefix'],
        created_at=stored['created_at'].isoformat(),
    )

330 

async def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
    """Resolve an API key to its owning user.

    The key is matched by its SHA-256 hex digest; only active keys with no
    expiry or a future expiry match. On a match the key's ``last_used_at``
    is set to NOW().

    Returns:
        All ``users`` columns plus ``key_id`` as a dict, or None when the
        database is disabled or the key is unknown, inactive or expired.
        NOTE(review): because the query selects ``u.*`` this includes
        ``password_hash``; callers should not echo the dict back to clients.
    """
    if not self.is_enabled():
        return None

    pool = await self.get_pool()
    digest = hashlib.sha256(api_key.encode()).hexdigest()

    async with pool.acquire() as conn:
        match = await conn.fetchrow(
            """
            SELECT u.*, k.id as key_id
            FROM users u
            JOIN api_keys k ON k.user_id = u.id
            WHERE k.key_hash = $1
            AND k.is_active = true
            AND (k.expires_at IS NULL OR k.expires_at > NOW())
            """,
            digest
        )
        if not match:
            return None

        await conn.execute(
            "UPDATE api_keys SET last_used_at = NOW() WHERE id = $1",
            match['key_id']
        )
        return dict(match)

363 

async def list_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
    """Return the user's API keys (metadata only, newest first).

    Hashes are never selected; each dict has id, name, key_prefix,
    is_active, created_at, last_used_at and expires_at.
    """
    if not self.is_enabled():
        return []

    pool = await self.get_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT id, name, key_prefix, is_active, created_at, last_used_at, expires_at
            FROM api_keys
            WHERE user_id = $1
            ORDER BY created_at DESC
            """,
            user_id
        )
    return list(map(dict, rows))

383 

async def deactivate_api_key(self, key_id: str, user_id: str) -> bool:
    """Mark one of the user's API keys inactive.

    The ``user_id`` filter ensures a user can only deactivate their own key.
    Returns True only when exactly one row was updated.
    """
    if self.is_enabled():
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            status = await conn.execute(
                """
                UPDATE api_keys
                SET is_active = false
                WHERE id = $1 AND user_id = $2
                """,
                key_id, user_id
            )
        return status == "UPDATE 1"
    return False

402 

403 # ========================================== 

404 # Scan Management 

405 # ========================================== 

406 

async def create_scan(
    self,
    user_id: str,
    scan_type: str,
    workflow_mode: str,
    metadata: Optional[Dict] = None
) -> str:
    """Insert a new scan row and return its id as a string.

    ``metadata`` (None → ``{}``) is serialised with json.dumps before
    binding. get_pool() creates the pool without a JSON type codec, and
    without one asyncpg only accepts ``str`` for json/jsonb parameters, so
    binding the dict directly fails ("expected str, got dict").
    log_webhook_event serialises its payload the same way.

    Raises:
        Exception: "Database not configured" when disabled, or any driver error.
    """
    if not self.is_enabled():
        raise Exception("Database not configured")

    import json

    pool = await self.get_pool()
    metadata_json = json.dumps(metadata or {})

    async with pool.acquire() as conn:
        scan = await conn.fetchrow(
            """
            INSERT INTO scans (user_id, scan_type, workflow_mode, metadata)
            VALUES ($1, $2, $3, $4)
            RETURNING id
            """,
            user_id, scan_type, workflow_mode, metadata_json
        )

    return str(scan['id'])

431 

async def save_scan(
    self,
    scan_id: str,
    status: str,
    findings: Optional[Dict] = None,
    findings_count: int = 0,
    metadata: Optional[Dict] = None
) -> bool:
    """Store a scan's final status and results and stamp completed_at.

    ``findings`` and ``metadata`` (None → ``{}``) are serialised with
    json.dumps before binding. get_pool() creates the pool without a JSON
    type codec, and without one asyncpg only accepts ``str`` for json/jsonb
    parameters, so binding the dicts directly fails.

    Returns:
        True if exactly one scan row was updated; False when disabled or no
        row matched ``scan_id``.
    """
    if not self.is_enabled():
        return False

    import json

    pool = await self.get_pool()
    findings_json = json.dumps(findings or {})
    metadata_json = json.dumps(metadata or {})

    async with pool.acquire() as conn:
        result = await conn.execute(
            """
            UPDATE scans
            SET status = $1,
                findings = $2,
                findings_count = $3,
                metadata = $4,
                completed_at = NOW()
            WHERE id = $5
            """,
            status, findings_json, findings_count, metadata_json, scan_id
        )

    return result == "UPDATE 1"

461 

async def get_scan(
    self,
    scan_id: str,
    user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Fetch a scan by id, optionally restricted to one owner.

    When ``user_id`` is truthy the scan must also belong to that user.
    Returns the row as a dict, or None if disabled or not found.
    """
    if not self.is_enabled():
        return None

    pool = await self.get_pool()

    if user_id:
        sql, args = "SELECT * FROM scans WHERE id = $1 AND user_id = $2", (scan_id, user_id)
    else:
        sql, args = "SELECT * FROM scans WHERE id = $1", (scan_id,)

    async with pool.acquire() as conn:
        row = await conn.fetchrow(sql, *args)

    return dict(row) if row else None

486 

async def list_scans(
    self,
    user_id: str,
    limit: int = 10,
    offset: int = 0,
    scan_type: Optional[str] = None,
    workflow_mode: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Page through a user's scans, newest first.

    ``scan_type`` and ``workflow_mode`` are optional equality filters,
    applied only when truthy. All values, including limit/offset, are bound
    as positional parameters.
    """
    if not self.is_enabled():
        return []

    pool = await self.get_pool()

    clauses = ["user_id = $1"]
    params: List[Any] = [user_id]
    for column, value in (("scan_type", scan_type), ("workflow_mode", workflow_mode)):
        if value:
            params.append(value)
            clauses.append(f"{column} = ${len(params)}")

    n = len(params)
    query = (
        "SELECT * FROM scans WHERE "
        + " AND ".join(clauses)
        + f" ORDER BY created_at DESC LIMIT ${n + 1} OFFSET ${n + 2}"
    )
    params += [limit, offset]

    async with pool.acquire() as conn:
        rows = await conn.fetch(query, *params)
    return [dict(row) for row in rows]

521 

522 # ========================================== 

523 # Rate Limiting & Usage Tracking 

524 # ========================================== 

525 

async def check_rate_limit(self, user_id: str) -> Dict[str, Any]:
    """Check whether the user is under their hourly request limit.

    Counts usage_logs rows from the last hour and compares them with the
    user's ``requests_per_hour``.

    Returns:
        {"allowed", "remaining", "limit", "used"} for known users;
        {"allowed": False, "remaining": 0} for unknown users;
        {"allowed": True, "remaining": 0} when the database is disabled.
    """
    if not self.is_enabled():
        return {"allowed": True, "remaining": 0}

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        user = await conn.fetchrow(
            "SELECT tier, requests_per_hour FROM users WHERE id = $1",
            user_id
        )

        if not user:
            return {"allowed": False, "remaining": 0}

        # The one-hour window is computed by the database so it does not
        # depend on how a naive Python datetime would be interpreted when
        # bound against the created_at column.
        count = await conn.fetchval(
            """
            SELECT COUNT(*) FROM usage_logs
            WHERE user_id = $1 AND created_at > NOW() - INTERVAL '1 hour'
            """,
            user_id
        )

    # Users whose limits were never initialised have NULL here; use the same
    # 0 default as initialize_usage_tracking instead of failing on None - int.
    limit = user['requests_per_hour'] or 0
    remaining = max(0, limit - count)
    allowed = count < limit

    return {
        "allowed": allowed,
        "remaining": remaining,
        "limit": limit,
        "used": count
    }

563 

async def log_request(
    self,
    user_id: str,
    endpoint: str,
    method: str,
    status_code: int,
    duration_ms: float
):
    """Append one API request to usage_logs (no-op when DB is disabled).

    These rows are what check_rate_limit counts.
    """
    if self.is_enabled():
        pool = await self.get_pool()
        row_values = (user_id, endpoint, method, status_code, duration_ms)
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO usage_logs (user_id, endpoint, method, status_code, duration_ms)
                VALUES ($1, $2, $3, $4, $5)
                """,
                *row_values
            )

586 

async def get_user_stats(self, user_id: str) -> Dict[str, Any]:
    """Summarise a user's tier, limits and scan activity.

    "This month" starts at 00:00 on the 1st of the current month as given
    by datetime.utcnow(). Returns an empty dict when the database is
    disabled or the user does not exist.
    """
    if not self.is_enabled():
        return {}

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        row = await conn.fetchrow(
            "SELECT * FROM users WHERE id = $1",
            user_id
        )
        if not row:
            return {}

        month_start = datetime.utcnow().replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        )
        this_month = await conn.fetchval(
            "SELECT COUNT(*) FROM scans WHERE user_id = $1 AND created_at >= $2",
            user_id, month_start
        )
        all_time = await conn.fetchval(
            "SELECT COUNT(*) FROM scans WHERE user_id = $1",
            user_id
        )
        findings_total = await conn.fetchval(
            "SELECT COALESCE(SUM(findings_count), 0) FROM scans WHERE user_id = $1",
            user_id
        )

    return {
        "tier": row['tier'],
        "scans_this_month": this_month,
        "total_scans": all_time,
        "vulnerabilities_found": findings_total,
        "scans_limit": row['scans_per_month'],
        "requests_limit": row['requests_per_hour']
    }

631 

632 # ========================================== 

633 # Webhook Event Logging 

634 # ========================================== 

635 

async def log_webhook_event(
    self,
    event_type: str,
    event_id: str,
    payload: Dict[str, Any]
):
    """Persist an incoming webhook event, ignoring duplicates by event_id.

    The payload is serialised to JSON text and cast to jsonb in SQL; a
    repeated event_id hits ON CONFLICT DO NOTHING.
    """
    if not self.is_enabled():
        return

    pool = await self.get_pool()

    import json
    serialized = json.dumps(payload)

    async with pool.acquire() as conn:
        await conn.execute(
            """
            INSERT INTO webhook_events (event_type, event_id, payload)
            VALUES ($1, $2, $3::jsonb)
            ON CONFLICT (event_id) DO NOTHING
            """,
            event_type, event_id, serialized
        )

660 

async def mark_webhook_processed(self, event_id: str):
    """Flag a stored webhook event as processed and record when."""
    if self.is_enabled():
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE webhook_events
                SET processed = true, processed_at = NOW()
                WHERE event_id = $1
                """,
                event_id
            )

677 

async def mark_webhook_error(self, event_id: str, error: str):
    """Record an error message against a stored webhook event."""
    if not self.is_enabled():
        return

    pool = await self.get_pool()
    sql = """
            UPDATE webhook_events
            SET error = $1
            WHERE event_id = $2
            """
    async with pool.acquire() as conn:
        await conn.execute(sql, error, event_id)

694 

695 # ========================================== 

696 # Device Authorization (CLI OAuth) 

697 # ========================================== 

698 

def generate_device_codes(self, user_code_length: int = 4) -> tuple[str, str]:
    """Generate a (device_code, user_code) pair for the CLI device flow.

    Args:
        user_code_length: Number of decimal digits in the human-typed user
            code. Defaults to 4 (short and easy to type, like GitHub CLI);
            longer codes reduce collisions between concurrent requests.

    Returns:
        (device_code, user_code): ``device_code`` is a URL-safe token from
        ``secrets.token_urlsafe(32)``; ``user_code`` is ``user_code_length``
        random digits drawn with ``secrets.choice``.

    Raises:
        ValueError: if ``user_code_length`` is less than 1.
    """
    if user_code_length < 1:
        raise ValueError("user_code_length must be at least 1")
    device_code = secrets.token_urlsafe(32)
    user_code = ''.join(secrets.choice('0123456789') for _ in range(user_code_length))
    return device_code, user_code

705 

async def create_device_authorization(self) -> Dict[str, Any]:
    """Start a CLI device-authorization flow.

    Generates a device code (polled by the CLI) and a short user code
    (typed by the user at the verification URI) and stores them with a
    15-minute expiry.

    User codes from generate_device_codes are only 4 digits by default, and
    authorize_device finds a request by user code alone. A freshly drawn
    user code that is still pending and unexpired for another device is
    therefore rejected and redrawn (up to 10 attempts). The expiry is
    computed by the database (NOW() + 15 minutes), the same clock that
    authorize_device compares against.

    Returns:
        Dict with device_code, user_code, verification_uri and expires_in
        (seconds).

    Raises:
        Exception: "Database not configured" when disabled, when no free
            user code is found, or on any driver error.
    """
    if not self.is_enabled():
        raise Exception("Database not configured")

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        for _attempt in range(10):
            device_code, user_code = self.generate_device_codes()
            clash = await conn.fetchval(
                """
                SELECT 1 FROM device_codes
                WHERE user_code = $1
                AND authorized = false
                AND expires_at > NOW()
                """,
                user_code
            )
            if not clash:
                break
        else:
            raise Exception("Could not allocate a unique device user code")

        auth = await conn.fetchrow(
            """
            INSERT INTO device_codes (device_code, user_code, expires_at)
            VALUES ($1, $2, NOW() + INTERVAL '15 minutes')
            RETURNING device_code, user_code, expires_at
            """,
            device_code, user_code
        )

    return {
        "device_code": auth['device_code'],
        "user_code": auth['user_code'],
        "verification_uri": "https://www.alprina.com/device",
        "expires_in": 900  # 15 minutes
    }

732 

async def check_device_authorization(
    self,
    device_code: str
) -> Optional[Dict[str, Any]]:
    """Poll the state of a CLI device-authorization request.

    Expiry is checked before authorization, so an expired request reports
    "expired" even if it had been authorized.

    Returns:
        None if the database is disabled or the device code is unknown;
        otherwise {"status": "expired"}, {"status": "pending"}, or
        {"status": "authorized", "user_id": <str>, "user_code": ...}.
    """
    if not self.is_enabled():
        return None

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        auth = await conn.fetchrow(
            """
            SELECT * FROM device_codes
            WHERE device_code = $1
            """,
            device_code
        )

    if not auth:
        return None

    # expires_at can come back timezone-aware (timestamptz column) or naive
    # (timestamp column). Comparing an aware value with a naive
    # datetime.utcnow() raises TypeError, so build a UTC "now" of the same
    # kind; naive values keep being interpreted as UTC, as before.
    from datetime import timezone
    expires_at = auth['expires_at']
    now = datetime.now(timezone.utc)
    if expires_at.tzinfo is None:
        now = now.replace(tzinfo=None)

    if expires_at < now:
        return {"status": "expired"}

    if auth['authorized'] and auth['user_id']:
        return {
            "status": "authorized",
            "user_id": str(auth['user_id']),
            "user_code": auth['user_code']
        }

    return {"status": "pending"}

769 

async def authorize_device(self, user_code: str, user_id: str) -> bool:
    """Bind a pending, unexpired device request to ``user_id``.

    A request is identified only by its short ``user_code``. Nothing in this
    service enforces uniqueness of pending user codes; whether the schema
    does is not visible here. The UPDATE is therefore applied only when
    exactly one pending, unexpired row carries the code. An ambiguous code
    now authorizes nothing, instead of authorizing every matching device
    for this user while still reporting failure.

    Returns:
        True if exactly one device was authorized, otherwise False.
    """
    if not self.is_enabled():
        return False

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        result = await conn.execute(
            """
            UPDATE device_codes
            SET user_id = $1, authorized = true
            WHERE user_code = $2
            AND expires_at > NOW()
            AND authorized = false
            AND (
                SELECT COUNT(*) FROM device_codes
                WHERE user_code = $2
                AND expires_at > NOW()
                AND authorized = false
            ) = 1
            """,
            user_id, user_code
        )

    return result == "UPDATE 1"

790 

791 # ========================================== 

792 # Team Management 

793 # ========================================== 

794 

async def get_team_members(self, owner_id: str) -> List[Dict[str, Any]]:
    """List members of every team subscription owned by ``owner_id``.

    Members are team_members rows whose subscription belongs to the owner
    in user_subscriptions, joined with users for contact details. Each dict
    has id, email, full_name, role and joined_at; oldest member first.
    """
    if not self.is_enabled():
        return []

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        rows = await conn.fetch(
            """
            SELECT 
                u.id,
                u.email,
                u.full_name,
                tm.role,
                tm.created_at as joined_at
            FROM team_members tm
            JOIN users u ON u.id = tm.user_id
            WHERE tm.subscription_id IN (
                SELECT id FROM user_subscriptions 
                WHERE user_id = $1
            )
            ORDER BY tm.created_at ASC
            """,
            owner_id
        )

    return list(map(dict, rows))

825 

async def get_team_member_by_email(
    self,
    owner_id: str,
    email: str
) -> Optional[Dict[str, Any]]:
    """Return the owner's team member with this email, if any.

    The email comparison is case-insensitive (LOWER on both sides).
    Returns {"id", "email", "role"} or None.
    """
    if not self.is_enabled():
        return None

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        hit = await conn.fetchrow(
            """
            SELECT 
                u.id,
                u.email,
                tm.role
            FROM team_members tm
            JOIN users u ON u.id = tm.user_id
            WHERE tm.subscription_id IN (
                SELECT id FROM user_subscriptions 
                WHERE user_id = $1
            )
            AND LOWER(u.email) = LOWER($2)
            """,
            owner_id, email
        )

    if not hit:
        return None
    return dict(hit)

856 

async def create_team_invitation(
    self,
    owner_id: str,
    invitee_email: str,
    role: str
) -> Dict[str, Any]:
    """Create a pending invitation for ``invitee_email`` to join the owner's team.

    Ensures the team_invitations table and its indexes exist, then inserts a
    row with a fresh URL-safe token. Per the CREATE TABLE below, status
    defaults to 'pending' and expires_at to 7 days from now.

    Args:
        owner_id: Id of the inviting team owner.
        invitee_email: Email address being invited.
        role: 'admin' or 'member' (enforced by the table's CHECK constraint).

    Returns:
        The inserted row (id, owner_id, owner_email, invitee_email, role,
        token, created_at) as a dict.

    Raises:
        Exception: "Database not configured" when disabled, or any driver
            error (e.g. a role rejected by the CHECK constraint).
    """
    if not self.is_enabled():
        raise Exception("Database not configured")

    pool = await self.get_pool()

    invitation_token = secrets.token_urlsafe(32)

    # Denormalise the owner's email onto the invitation. The service exposes
    # get_user_by_id; there is no get_user method, which previously raised
    # AttributeError here.
    owner = await self.get_user_by_id(owner_id)
    owner_email = owner.get("email") if owner else None

    async with pool.acquire() as conn:
        # Create team_invitations table if it doesn't exist
        await conn.execute(
            """
            CREATE TABLE IF NOT EXISTS team_invitations (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
                owner_email TEXT,
                invitee_email TEXT NOT NULL,
                role TEXT NOT NULL CHECK (role IN ('admin', 'member')),
                token TEXT NOT NULL UNIQUE,
                status TEXT DEFAULT 'pending' CHECK (status IN ('pending', 'accepted', 'expired')),
                expires_at TIMESTAMPTZ DEFAULT NOW() + INTERVAL '7 days',
                created_at TIMESTAMPTZ DEFAULT NOW()
            );

            CREATE INDEX IF NOT EXISTS idx_team_invitations_token ON team_invitations(token);
            CREATE INDEX IF NOT EXISTS idx_team_invitations_owner ON team_invitations(owner_id);
            CREATE INDEX IF NOT EXISTS idx_team_invitations_email ON team_invitations(invitee_email);
            """
        )

        invitation = await conn.fetchrow(
            """
            INSERT INTO team_invitations (
                owner_id, owner_email, invitee_email, role, token
            )
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, owner_id, owner_email, invitee_email, role, token, created_at
            """,
            owner_id, owner_email, invitee_email, role, invitation_token
        )

    return dict(invitation)

911 

async def get_team_invitation(self, token: str) -> Optional[Dict[str, Any]]:
    """Look up a still-usable invitation by its token.

    Only invitations with status 'pending' whose expires_at is in the future
    are returned; anything else yields None.
    """
    if self.is_enabled():
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            invite = await conn.fetchrow(
                """
                SELECT * FROM team_invitations
                WHERE token = $1
                AND status = 'pending'
                AND expires_at > NOW()
                """,
                token
            )
        if invite:
            return dict(invite)
    return None

931 

async def delete_team_invitation(self, token: str) -> bool:
    """Retire an invitation by marking it 'accepted' (the row is not deleted).

    Returns True only if exactly one invitation row was updated.
    """
    if not self.is_enabled():
        return False

    pool = await self.get_pool()
    async with pool.acquire() as conn:
        outcome = await conn.execute(
            """
            UPDATE team_invitations
            SET status = 'accepted'
            WHERE token = $1
            """,
            token
        )
    return outcome == "UPDATE 1"

950 

async def add_team_member(
    self,
    owner_id: str,
    member_id: str,
    role: str
) -> bool:
    """Attach a user to the owner's most recent subscription as a team member.

    Inserting an existing (subscription, user) pair is a no-op via
    ON CONFLICT DO NOTHING and still returns True.

    Returns:
        False when the database is disabled or the owner has no
        subscription; True otherwise.
    """
    if not self.is_enabled():
        return False

    pool = await self.get_pool()

    async with pool.acquire() as conn:
        latest = await conn.fetchrow(
            """
            SELECT id FROM user_subscriptions
            WHERE user_id = $1
            ORDER BY created_at DESC
            LIMIT 1
            """,
            owner_id
        )

        if not latest:
            logger.error(f"No subscription found for owner {owner_id}")
            return False

        subscription_id = latest["id"]
        await conn.execute(
            """
            INSERT INTO team_members (subscription_id, user_id, role)
            VALUES ($1, $2, $3)
            ON CONFLICT (subscription_id, user_id) DO NOTHING
            """,
            subscription_id, member_id, role
        )

    logger.info(f"Added team member {member_id} to subscription {latest['id']}")
    return True

991 

async def remove_team_member(
    self,
    owner_id: str,
    member_id: str
) -> bool:
    """Remove ``member_id`` from any team subscription owned by ``owner_id``.

    Returns True only when exactly one team_members row was deleted.
    """
    if self.is_enabled():
        pool = await self.get_pool()
        async with pool.acquire() as conn:
            status = await conn.execute(
                """
                DELETE FROM team_members
                WHERE subscription_id IN (
                    SELECT id FROM user_subscriptions 
                    WHERE user_id = $1
                )
                AND user_id = $2
                """,
                owner_id, member_id
            )
        return status == "DELETE 1"
    return False

1017 

1018 # ========================================== 

1019 # Cleanup 

1020 # ========================================== 

1021 

async def close(self):
    """Close the connection pool, if one was created, and forget it.

    ``self.pool`` is reset to None (even if closing raises) so a later
    get_pool() call creates a fresh pool instead of returning the closed one.
    """
    if self.pool:
        try:
            await self.pool.close()
        finally:
            self.pool = None
        logger.info("Neon connection pool closed")

1027 

1028 

# Shared module-level instance. Construction only reads DATABASE_URL (and
# logs); the connection pool itself is created lazily by get_pool().
neon_service = NeonService()