Coverage for src/alprina_cli/api/services/neon_service.py: 0%
339 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 18:07 +0100
1"""
2Neon Database Service
3Replaces SupabaseService with direct PostgreSQL access via asyncpg.
4"""
import asyncio
import hashlib
import json
import os
import secrets
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List

import asyncpg
import bcrypt
from loguru import logger
class NeonService:
    """Service for Neon PostgreSQL database operations.

    Replaces SupabaseService with direct PostgreSQL access via asyncpg. The
    connection pool is created lazily by :meth:`get_pool`.

    When ``DATABASE_URL`` is not set, most methods short-circuit with a
    neutral result (``None``, ``False``, ``[]``, ``{}``, or an "allowed"
    rate-limit result). Methods that must create a row raise
    ``Exception("Database not configured")`` instead.
    """

    # How many times create_device_authorization() retries to find a user
    # code that no other pending, unexpired device is currently using.
    _DEVICE_CODE_ATTEMPTS = 10

    def __init__(self):
        """Read ``DATABASE_URL`` and prepare a lazily created connection pool."""
        self.database_url = os.getenv("DATABASE_URL")
        self.pool = None  # Created on first use by get_pool()
        self._pool_lock = None  # asyncio.Lock, created inside the running loop
        self._team_invitations_ready = False  # DDL for team_invitations done?

        if not self.database_url:
            logger.warning("DATABASE_URL not found - database features disabled")
            self.enabled = False
        else:
            self.enabled = True
            logger.info("Neon service initialized")

    async def get_pool(self) -> "asyncpg.Pool":
        """Return the shared connection pool, creating it on first use.

        First-time creation is serialised with a lock. Without it, concurrent
        first callers would each create a pool and all but one would leak.

        Raises:
            Exception: Whatever ``asyncpg.create_pool`` raises. It is logged
                and re-raised.
        """
        if self.pool is not None:
            return self.pool

        if getattr(self, "_pool_lock", None) is None:
            # Created here, inside the running event loop, not at import time.
            self._pool_lock = asyncio.Lock()

        async with self._pool_lock:
            if self.pool is None:  # another coroutine may have created it meanwhile
                try:
                    self.pool = await asyncpg.create_pool(
                        self.database_url,
                        min_size=2,
                        max_size=10,
                        command_timeout=60,
                    )
                    logger.info("✅ Neon connection pool created successfully")
                except Exception as e:
                    logger.error(f"❌ Failed to create Neon connection pool: {e}")
                    raise
        return self.pool

    def is_enabled(self) -> bool:
        """Return True when ``DATABASE_URL`` was configured at construction."""
        return self.enabled

    # ==========================================
    # Internal helpers
    # ==========================================

    @staticmethod
    def _rows_affected(status: str) -> int:
        """Parse the row count from an asyncpg command status such as ``"DELETE 2"``.

        Returns 0 when the status is missing or has no trailing integer.
        """
        try:
            return int(status.rsplit(" ", 1)[-1])
        except (AttributeError, ValueError):
            return 0

    @staticmethod
    def _to_json(value: Any) -> str:
        """Serialise *value* for a JSON/JSONB parameter.

        The pool is created without a JSON type codec, so asyncpg expects JSON
        text rather than a Python dict. Falsy values become ``"{}"``, matching
        the ``value or {}`` defaults callers relied on. Non-empty strings are
        assumed to be already-serialised JSON and are passed through.
        """
        if isinstance(value, str) and value:
            return value
        return json.dumps(value or {})

    async def _fetchrow_dict(self, query: str, *args: Any) -> Optional[Dict[str, Any]]:
        """Run *query* and return its first row as a dict.

        Returns None when no row matches or the database is disabled.
        """
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            row = await conn.fetchrow(query, *args)
        return dict(row) if row else None

    # ==========================================
    # User Management
    # ==========================================

    async def create_user(
        self,
        email: str,
        password: str,
        full_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Create a password-based user together with a default API key.

        The user row and its API key are written in one transaction, so a
        failure cannot leave a user without a key.

        Returns:
            Dict with ``user_id``, ``email``, ``full_name``, ``tier``,
            ``api_key`` (plaintext; only its SHA-256 hash is stored, so this
            is the only time it is available) and ``created_at`` (ISO 8601).

        Raises:
            Exception: If the database is not configured. asyncpg errors from
                the INSERT propagate unchanged.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        if isinstance(password, str):
            password = password.encode('utf-8')
        # bcrypt is deliberately slow; run it off the event loop.
        hashed = await asyncio.to_thread(bcrypt.hashpw, password, bcrypt.gensalt())
        password_hash = hashed.decode('utf-8')

        api_key = self.generate_api_key()

        async with pool.acquire() as conn:
            async with conn.transaction():
                user = await conn.fetchrow(
                    """
                    INSERT INTO users (email, password_hash, full_name, tier)
                    VALUES ($1, $2, $3, 'none')
                    RETURNING id, email, full_name, tier, created_at
                    """,
                    email, password_hash, full_name
                )
                await self._insert_api_key(
                    conn, str(user['id']), api_key, "Default API Key"
                )

        logger.info(f"Created user: {email}")

        return {
            "user_id": str(user['id']),
            "email": user['email'],
            "full_name": user['full_name'],
            "tier": user['tier'],
            "api_key": api_key,
            "created_at": user['created_at'].isoformat()
        }

    async def authenticate_user(
        self,
        email: str,
        password: str
    ) -> Optional[Dict[str, Any]]:
        """Authenticate a user with email and password.

        Returns:
            The full ``users`` row as a dict on success. Returns None for an
            unknown email, a user without a password (e.g. created from a
            subscription), a wrong password, a malformed stored hash, or a
            disabled database.
        """
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT * FROM users WHERE email = $1",
                email
            )

        if not user or not user['password_hash']:
            return None

        if isinstance(password, str):
            password = password.encode('utf-8')

        try:
            matches = await asyncio.to_thread(
                bcrypt.checkpw, password, user['password_hash'].encode('utf-8')
            )
        except ValueError:
            # bcrypt raises ValueError (e.g. "Invalid salt") for a corrupt
            # stored hash; treat that as a failed login rather than a 500.
            return None

        # NOTE(review): the returned row includes password_hash; confirm
        # callers never serialise it to clients.
        return dict(user) if matches else None

    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the ``users`` row with this id as a dict, or None."""
        return await self._fetchrow_dict(
            "SELECT * FROM users WHERE id = $1",
            user_id
        )

    async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Return the ``users`` row with this exact email as a dict, or None."""
        return await self._fetchrow_dict(
            "SELECT * FROM users WHERE email = $1",
            email
        )

    async def get_user_by_subscription(
        self,
        subscription_id: str
    ) -> Optional[Dict[str, Any]]:
        """Return the user whose ``polar_subscription_id`` matches, or None."""
        return await self._fetchrow_dict(
            "SELECT * FROM users WHERE polar_subscription_id = $1",
            subscription_id
        )

    async def create_user_from_subscription(
        self,
        email: str,
        polar_customer_id: str,
        polar_subscription_id: str,
        tier: str,
        billing_period: str = "monthly",
        has_metering: bool = True,
        scans_included: int = 0,
        period_start: datetime = None,
        period_end: datetime = None,
        seats_included: int = 1
    ) -> Dict[str, Any]:
        """Create a user from a Polar subscription (no password needed).

        The user starts with status ``'active'``, 0 scans used this period,
        1 seat used and 0 extra seats.

        Returns:
            The ``RETURNING id, email, tier, created_at`` row as a dict.
            Values are as asyncpg decodes them, so ``id`` is not stringified,
            unlike :meth:`create_user`.

        Raises:
            Exception: If the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                """
                INSERT INTO users (
                    email,
                    tier,
                    polar_customer_id,
                    polar_subscription_id,
                    subscription_status,
                    billing_period,
                    has_metering,
                    scans_included,
                    scans_used_this_period,
                    period_start,
                    period_end,
                    seats_included,
                    seats_used,
                    extra_seats
                )
                VALUES ($1, $2, $3, $4, 'active', $5, $6, $7, 0, $8, $9, $10, 1, 0)
                RETURNING id, email, tier, created_at
                """,
                email, tier, polar_customer_id, polar_subscription_id,
                billing_period, has_metering, scans_included,
                period_start, period_end, seats_included
            )

        logger.info(f"Created user from subscription: {email}")

        return dict(user)

    async def update_user(
        self,
        user_id: str,
        updates: Dict[str, Any]
    ) -> bool:
        """Update columns of a single user.

        Args:
            user_id: Id of the user row to update.
            updates: Mapping of column name to new value. Column names are
                interpolated into the SQL, so each must be a plain ASCII
                identifier. Values are always bound as parameters.

        Returns:
            True if exactly one row was updated. False when the database is
            disabled, ``updates`` is empty, or no row matched.

        Raises:
            ValueError: If a key is not a plain identifier. This guards
                against SQL injection through column names.
        """
        if not self.is_enabled():
            return False
        if not updates:
            # An empty SET clause would be a SQL syntax error.
            return False

        set_parts = []
        for idx, key in enumerate(updates, start=1):
            if not (isinstance(key, str) and key.isascii() and key.isidentifier()):
                raise ValueError(f"Invalid column name: {key!r}")
            set_parts.append(f"{key} = ${idx}")
        values = [*updates.values(), user_id]

        query = f"""
            UPDATE users
            SET {', '.join(set_parts)}
            WHERE id = ${len(values)}
        """

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(query, *values)
        return result == "UPDATE 1"

    async def initialize_usage_tracking(
        self,
        user_id: str,
        tier: str
    ):
        """Copy the tier's rate/scan limits from polar_service onto the user row."""
        from ..services.polar_service import polar_service
        limits = polar_service.get_tier_limits(tier)

        await self.update_user(
            user_id,
            {
                "requests_per_hour": limits.get("api_requests_per_hour", 0),
                "scans_per_month": limits.get("scans_per_month", 0)
            }
        )

    async def increment_user_scans(self, user_id: str):
        """Record one more scan used in the user's current billing period.

        ``scans_per_month`` holds the monthly *limit*: initialize_usage_tracking
        sets it and get_user_stats reports it as ``scans_limit``. Usage is
        therefore counted in ``scans_used_this_period``. Incrementing the
        limit would grant an extra scan every time one ran.
        """
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE users
                SET scans_used_this_period = COALESCE(scans_used_this_period, 0) + 1
                WHERE id = $1
                """,
                user_id
            )

    # ==========================================
    # API Key Management
    # ==========================================

    def generate_api_key(self) -> str:
        """Generate a new random API key prefixed with ``alprina_sk_``."""
        return f"alprina_sk_{secrets.token_urlsafe(32)}"

    async def _insert_api_key(
        self,
        conn: "asyncpg.Connection",
        user_id: str,
        api_key: str,
        name: str,
        expires_at: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Insert an API key row on an already-acquired connection.

        Only a SHA-256 digest and a short display prefix are stored, never
        the key itself.
        """
        key_hash = hashlib.sha256(api_key.encode()).hexdigest()
        key_prefix = api_key[:16]  # First 16 chars for display

        key = await conn.fetchrow(
            """
            INSERT INTO api_keys (user_id, key_hash, key_prefix, name, expires_at)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, name, key_prefix, created_at
            """,
            user_id, key_hash, key_prefix, name, expires_at
        )

        return {
            "id": str(key['id']),
            "name": key['name'],
            "key_prefix": key['key_prefix'],
            "created_at": key['created_at'].isoformat()
        }

    async def create_api_key(
        self,
        user_id: str,
        api_key: str,
        name: str,
        expires_at: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Store a new API key for a user.

        Returns:
            Dict with ``id``, ``name``, ``key_prefix`` and ``created_at``
            (ISO 8601).

        Raises:
            Exception: If the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            return await self._insert_api_key(conn, user_id, api_key, name, expires_at)

    async def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Look up an active, unexpired API key.

        On a match, the key's ``last_used_at`` is stamped.

        Returns:
            The owning ``users`` row plus ``key_id`` as a dict, or None.
        """
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        key_hash = hashlib.sha256(api_key.encode()).hexdigest()

        async with pool.acquire() as conn:
            result = await conn.fetchrow(
                """
                SELECT u.*, k.id as key_id
                FROM users u
                JOIN api_keys k ON k.user_id = u.id
                WHERE k.key_hash = $1
                AND k.is_active = true
                AND (k.expires_at IS NULL OR k.expires_at > NOW())
                """,
                key_hash
            )

            if result:
                await conn.execute(
                    "UPDATE api_keys SET last_used_at = NOW() WHERE id = $1",
                    result['key_id']
                )

                return dict(result)

        return None

    async def list_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """List a user's API keys (metadata only, no hashes), newest first."""
        if not self.is_enabled():
            return []

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            keys = await conn.fetch(
                """
                SELECT id, name, key_prefix, is_active, created_at, last_used_at, expires_at
                FROM api_keys
                WHERE user_id = $1
                ORDER BY created_at DESC
                """,
                user_id
            )

        return [dict(key) for key in keys]

    async def deactivate_api_key(self, key_id: str, user_id: str) -> bool:
        """Deactivate one of the user's API keys.

        Returns True if that key belonged to the user and was updated.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE api_keys
                SET is_active = false
                WHERE id = $1 AND user_id = $2
                """,
                key_id, user_id
            )

        return result == "UPDATE 1"

    # ==========================================
    # Scan Management
    # ==========================================

    async def create_scan(
        self,
        user_id: str,
        scan_type: str,
        workflow_mode: str,
        metadata: Optional[Dict] = None
    ) -> str:
        """Insert a new scan row and return its id as a string.

        Raises:
            Exception: If the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            scan = await conn.fetchrow(
                """
                INSERT INTO scans (user_id, scan_type, workflow_mode, metadata)
                VALUES ($1, $2, $3, $4)
                RETURNING id
                """,
                user_id, scan_type, workflow_mode, self._to_json(metadata)
            )

        return str(scan['id'])

    async def save_scan(
        self,
        scan_id: str,
        status: str,
        findings: Optional[Dict] = None,
        findings_count: int = 0,
        metadata: Optional[Dict] = None
    ) -> bool:
        """Store scan results and stamp ``completed_at``.

        A missing ``findings`` or ``metadata`` is written as ``{}``.
        Returns True if exactly one scan row was updated.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE scans
                SET status = $1,
                    findings = $2,
                    findings_count = $3,
                    metadata = $4,
                    completed_at = NOW()
                WHERE id = $5
                """,
                status, self._to_json(findings), findings_count,
                self._to_json(metadata), scan_id
            )

        return result == "UPDATE 1"

    async def get_scan(
        self,
        scan_id: str,
        user_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Return a scan by id as a dict, or None.

        When ``user_id`` is given, the scan must also belong to that user.
        """
        if user_id:
            return await self._fetchrow_dict(
                "SELECT * FROM scans WHERE id = $1 AND user_id = $2",
                scan_id, user_id
            )
        return await self._fetchrow_dict(
            "SELECT * FROM scans WHERE id = $1",
            scan_id
        )

    async def list_scans(
        self,
        user_id: str,
        limit: int = 10,
        offset: int = 0,
        scan_type: Optional[str] = None,
        workflow_mode: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """List a user's scans, newest first, with optional filters and pagination."""
        if not self.is_enabled():
            return []

        pool = await self.get_pool()

        query = "SELECT * FROM scans WHERE user_id = $1"
        params = [user_id]
        idx = 2

        if scan_type:
            query += f" AND scan_type = ${idx}"
            params.append(scan_type)
            idx += 1

        if workflow_mode:
            query += f" AND workflow_mode = ${idx}"
            params.append(workflow_mode)
            idx += 1

        query += f" ORDER BY created_at DESC LIMIT ${idx} OFFSET ${idx+1}"
        params.extend([limit, offset])

        async with pool.acquire() as conn:
            scans = await conn.fetch(query, *params)
        return [dict(scan) for scan in scans]

    # ==========================================
    # Rate Limiting & Usage Tracking
    # ==========================================

    async def check_rate_limit(self, user_id: str) -> Dict[str, Any]:
        """Check the user's ``usage_logs`` count for the last hour against their limit.

        A NULL ``requests_per_hour`` is treated as 0. That matches the 0
        default used by initialize_usage_tracking, and it avoids a TypeError.

        Returns:
            Dict with ``allowed`` and ``remaining``. For a known user it also
            has ``limit`` and ``used``.
        """
        if not self.is_enabled():
            return {"allowed": True, "remaining": 0}

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT tier, requests_per_hour FROM users WHERE id = $1",
                user_id
            )

            if not user:
                return {"allowed": False, "remaining": 0}

            # The window is computed by the database clock. This avoids
            # asyncpg reinterpreting a naive Python datetime in local time.
            count = await conn.fetchval(
                """
                SELECT COUNT(*) FROM usage_logs
                WHERE user_id = $1 AND created_at > NOW() - INTERVAL '1 hour'
                """,
                user_id
            )

        limit = user['requests_per_hour'] or 0
        remaining = max(0, limit - count)
        allowed = count < limit

        return {
            "allowed": allowed,
            "remaining": remaining,
            "limit": limit,
            "used": count
        }

    async def log_request(
        self,
        user_id: str,
        endpoint: str,
        method: str,
        status_code: int,
        duration_ms: float
    ):
        """Insert one API request into ``usage_logs``."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO usage_logs (user_id, endpoint, method, status_code, duration_ms)
                VALUES ($1, $2, $3, $4, $5)
                """,
                user_id, endpoint, method, status_code, duration_ms
            )

    async def get_user_stats(self, user_id: str) -> Dict[str, Any]:
        """Summarise a user's tier, limits and scan counts.

        Returns {} when the database is disabled or the user is unknown.
        """
        if not self.is_enabled():
            return {}

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            user = await conn.fetchrow(
                "SELECT * FROM users WHERE id = $1",
                user_id
            )

            if not user:
                return {}

            first_day = datetime.utcnow().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
            # One pass over the user's scans instead of three separate queries.
            stats = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) FILTER (WHERE created_at >= $2) AS scans_this_month,
                    COUNT(*) AS total_scans,
                    COALESCE(SUM(findings_count), 0) AS vulnerabilities_found
                FROM scans
                WHERE user_id = $1
                """,
                user_id, first_day
            )

        return {
            "tier": user['tier'],
            "scans_this_month": stats['scans_this_month'],
            "total_scans": stats['total_scans'],
            "vulnerabilities_found": stats['vulnerabilities_found'],
            "scans_limit": user['scans_per_month'],
            "requests_limit": user['requests_per_hour']
        }

    # ==========================================
    # Webhook Event Logging
    # ==========================================

    async def log_webhook_event(
        self,
        event_type: str,
        event_id: str,
        payload: Dict[str, Any]
    ):
        """Record a webhook event; an already-seen ``event_id`` is ignored."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        payload_json = json.dumps(payload)

        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO webhook_events (event_type, event_id, payload)
                VALUES ($1, $2, $3::jsonb)
                ON CONFLICT (event_id) DO NOTHING
                """,
                event_type, event_id, payload_json
            )

    async def mark_webhook_processed(self, event_id: str):
        """Flag a webhook event as processed and stamp ``processed_at``."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE webhook_events
                SET processed = true, processed_at = NOW()
                WHERE event_id = $1
                """,
                event_id
            )

    async def mark_webhook_error(self, event_id: str, error: str):
        """Store an error message on a webhook event."""
        if not self.is_enabled():
            return

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE webhook_events
                SET error = $1
                WHERE event_id = $2
                """,
                error, event_id
            )

    # ==========================================
    # Device Authorization (CLI OAuth)
    # ==========================================

    def generate_device_codes(self) -> tuple[str, str]:
        """Return ``(device_code, user_code)``.

        ``device_code`` is a random URL-safe token. ``user_code`` is 4 random
        digits.
        """
        device_code = secrets.token_urlsafe(32)
        # User code: 4 digits (like GitHub CLI) - easy to type.
        # NOTE(review): only 10,000 possible codes; confirm authorize_device()
        # is rate-limited at the API layer to resist guessing.
        user_code = ''.join(secrets.choice('0123456789') for _ in range(4))
        return device_code, user_code

    async def create_device_authorization(self) -> Dict[str, Any]:
        """Start a CLI device-authorization flow that expires in 15 minutes.

        The user code is retried until no other pending, unexpired device
        shares it, so authorize_device() can resolve it to one device.

        Raises:
            Exception: If the database is not configured, or no free user
                code was found within ``_DEVICE_CODE_ATTEMPTS`` tries.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            for _ in range(self._DEVICE_CODE_ATTEMPTS):
                device_code, user_code = self.generate_device_codes()
                clash = await conn.fetchval(
                    """
                    SELECT 1 FROM device_codes
                    WHERE user_code = $1
                    AND authorized = false
                    AND expires_at > NOW()
                    """,
                    user_code
                )
                if not clash:
                    break
            else:
                raise Exception("Could not allocate a unique device user code")

            # Expiry uses the database clock, like the checks in
            # check_device_authorization() and authorize_device().
            auth = await conn.fetchrow(
                """
                INSERT INTO device_codes (device_code, user_code, expires_at)
                VALUES ($1, $2, NOW() + INTERVAL '15 minutes')
                RETURNING device_code, user_code, expires_at
                """,
                device_code, user_code
            )

        return {
            "device_code": auth['device_code'],
            "user_code": auth['user_code'],
            "verification_uri": "https://www.alprina.com/device",
            "expires_in": 900  # 15 minutes, matching the INTERVAL above
        }

    async def check_device_authorization(
        self,
        device_code: str
    ) -> Optional[Dict[str, Any]]:
        """Poll a device code.

        Returns:
            None for an unknown code. Otherwise a dict whose ``status`` is
            ``"expired"``, ``"authorized"`` (with ``user_id`` and
            ``user_code``) or ``"pending"``. Expiry is checked first.
        """
        if not self.is_enabled():
            return None

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            # Compare in SQL: comparing an aware timestamptz value with a naive
            # Python utcnow() would raise TypeError.
            auth = await conn.fetchrow(
                """
                SELECT *, expires_at < NOW() AS expired_now
                FROM device_codes
                WHERE device_code = $1
                """,
                device_code
            )

        if not auth:
            return None

        if auth['expired_now']:
            return {"status": "expired"}

        if auth['authorized'] and auth['user_id']:
            return {
                "status": "authorized",
                "user_id": str(auth['user_id']),
                "user_code": auth['user_code']
            }

        return {"status": "pending"}

    async def authorize_device(self, user_code: str, user_id: str) -> bool:
        """Attach ``user_id`` to the single pending, unexpired device with this code.

        Returns False and changes nothing when no device matches. It does the
        same when several pending devices share the code, rather than
        authorizing a device the user never started.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            async with conn.transaction():
                candidates = await conn.fetch(
                    """
                    SELECT device_code FROM device_codes
                    WHERE user_code = $1
                    AND expires_at > NOW()
                    AND authorized = false
                    FOR UPDATE
                    """,
                    user_code
                )
                if len(candidates) != 1:
                    return False

                result = await conn.execute(
                    """
                    UPDATE device_codes
                    SET user_id = $1, authorized = true
                    WHERE device_code = $2
                    """,
                    user_id, candidates[0]['device_code']
                )

        return self._rows_affected(result) == 1

    # ==========================================
    # Team Management
    # ==========================================

    async def get_team_members(self, owner_id: str) -> List[Dict[str, Any]]:
        """List members of any subscription owned by ``owner_id``, oldest first."""
        if not self.is_enabled():
            return []

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            # Join with users to get email and other details.
            members = await conn.fetch(
                """
                SELECT
                    u.id,
                    u.email,
                    u.full_name,
                    tm.role,
                    tm.created_at as joined_at
                FROM team_members tm
                JOIN users u ON u.id = tm.user_id
                WHERE tm.subscription_id IN (
                    SELECT id FROM user_subscriptions
                    WHERE user_id = $1
                )
                ORDER BY tm.created_at ASC
                """,
                owner_id
            )

        return [dict(member) for member in members]

    async def get_team_member_by_email(
        self,
        owner_id: str,
        email: str
    ) -> Optional[Dict[str, Any]]:
        """Return the owner's team member with this email (case-insensitive), or None."""
        return await self._fetchrow_dict(
            """
            SELECT
                u.id,
                u.email,
                tm.role
            FROM team_members tm
            JOIN users u ON u.id = tm.user_id
            WHERE tm.subscription_id IN (
                SELECT id FROM user_subscriptions
                WHERE user_id = $1
            )
            AND LOWER(u.email) = LOWER($2)
            """,
            owner_id, email
        )

    async def _ensure_team_invitations_table(self, conn: "asyncpg.Connection") -> None:
        """Create ``team_invitations`` and its indexes once per service instance."""
        if getattr(self, "_team_invitations_ready", False):
            return

        # NOTE(review): uuid_generate_v4() needs the uuid-ossp extension.
        await conn.execute(
            """
            CREATE TABLE IF NOT EXISTS team_invitations (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
                owner_email TEXT,
                invitee_email TEXT NOT NULL,
                role TEXT NOT NULL CHECK (role IN ('admin', 'member')),
                token TEXT NOT NULL UNIQUE,
                status TEXT DEFAULT 'pending' CHECK (status IN ('pending', 'accepted', 'expired')),
                expires_at TIMESTAMPTZ DEFAULT NOW() + INTERVAL '7 days',
                created_at TIMESTAMPTZ DEFAULT NOW()
            );

            CREATE INDEX IF NOT EXISTS idx_team_invitations_token ON team_invitations(token);
            CREATE INDEX IF NOT EXISTS idx_team_invitations_owner ON team_invitations(owner_id);
            CREATE INDEX IF NOT EXISTS idx_team_invitations_email ON team_invitations(invitee_email);
            """
        )
        self._team_invitations_ready = True

    async def create_team_invitation(
        self,
        owner_id: str,
        invitee_email: str,
        role: str
    ) -> Dict[str, Any]:
        """Create a pending team invitation with a random token.

        Per the table defaults, the invitation expires after 7 days.

        Raises:
            Exception: If the database is not configured.
        """
        if not self.is_enabled():
            raise Exception("Database not configured")

        pool = await self.get_pool()

        invitation_token = secrets.token_urlsafe(32)

        # Previously called a non-existent self.get_user(), which raised AttributeError.
        owner = await self.get_user_by_id(owner_id)
        owner_email = owner.get("email") if owner else None

        async with pool.acquire() as conn:
            await self._ensure_team_invitations_table(conn)

            invitation = await conn.fetchrow(
                """
                INSERT INTO team_invitations (
                    owner_id, owner_email, invitee_email, role, token
                )
                VALUES ($1, $2, $3, $4, $5)
                RETURNING id, owner_id, owner_email, invitee_email, role, token, created_at
                """,
                owner_id, owner_email, invitee_email, role, invitation_token
            )

        return dict(invitation)

    async def get_team_invitation(self, token: str) -> Optional[Dict[str, Any]]:
        """Return a pending, unexpired invitation by token, or None."""
        return await self._fetchrow_dict(
            """
            SELECT * FROM team_invitations
            WHERE token = $1
            AND status = 'pending'
            AND expires_at > NOW()
            """,
            token
        )

    async def delete_team_invitation(self, token: str) -> bool:
        """Mark the invitation as accepted (the row is kept, not deleted)."""
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE team_invitations
                SET status = 'accepted'
                WHERE token = $1
                """,
                token
            )

        return result == "UPDATE 1"

    async def add_team_member(
        self,
        owner_id: str,
        member_id: str,
        role: str
    ) -> bool:
        """Add a member to the owner's most recent subscription.

        Adding an existing member is a no-op that still returns True.
        Returns False if the owner has no subscription.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            subscription = await conn.fetchrow(
                """
                SELECT id FROM user_subscriptions
                WHERE user_id = $1
                ORDER BY created_at DESC
                LIMIT 1
                """,
                owner_id
            )

            if not subscription:
                logger.error(f"No subscription found for owner {owner_id}")
                return False

            await conn.execute(
                """
                INSERT INTO team_members (subscription_id, user_id, role)
                VALUES ($1, $2, $3)
                ON CONFLICT (subscription_id, user_id) DO NOTHING
                """,
                subscription["id"], member_id, role
            )

        logger.info(f"Added team member {member_id} to subscription {subscription['id']}")
        return True

    async def remove_team_member(
        self,
        owner_id: str,
        member_id: str
    ) -> bool:
        """Remove a member from every subscription owned by ``owner_id``.

        Returns True if at least one membership row was deleted. An owner
        may have several subscriptions, so an exact "DELETE 1" check could
        report failure after a successful removal.
        """
        if not self.is_enabled():
            return False

        pool = await self.get_pool()

        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                DELETE FROM team_members
                WHERE subscription_id IN (
                    SELECT id FROM user_subscriptions
                    WHERE user_id = $1
                )
                AND user_id = $2
                """,
                owner_id, member_id
            )

        return self._rows_affected(result) >= 1

    # ==========================================
    # Cleanup
    # ==========================================

    async def close(self):
        """Close the connection pool.

        The cached pool is reset, so a later get_pool() creates a fresh pool
        instead of returning the closed one.
        """
        if self.pool:
            await self.pool.close()
            self.pool = None
            logger.info("Neon connection pool closed")
# Shared, module-level service instance (constructed once, at import time).
neon_service: NeonService = NeonService()