Source code for subscriptionkore.services.entitlement_service

"""Entitlement resolution service."""

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import structlog

from subscriptionkore.core.exceptions import EntitlementNotFoundError, UsageLimitExceededError
from subscriptionkore.core.models import (
    CustomerEntitlement,
    EntitlementOverride,
    EntitlementSource,
    EntitlementValueType,
    SubscriptionStatus,
)

if TYPE_CHECKING:
    from subscriptionkore.ports.cache import CachePort
    from subscriptionkore.ports.repository import (
        EntitlementOverrideRepository,
        EntitlementRepository,
        PlanRepository,
        SubscriptionRepository,
    )

logger = structlog.get_logger()


[docs] class EntitlementService: """ Resolves and manages customer entitlements. Entitlements are derived from: 1. Customer-level overrides (highest priority) 2. Active subscriptionkore plan entitlements 3. Trial entitlements 4. Default entitlements (lowest priority) """ def __init__( self, entitlement_repo: EntitlementRepository, override_repo: EntitlementOverrideRepository, subscriptionkore_repo: SubscriptionRepository, plan_repo: PlanRepository, cache: CachePort | None = None, cache_ttl: int = 300, ) -> None: self._entitlement_repo = entitlement_repo self._override_repo = override_repo self._subscriptionkore_repo = subscriptionkore_repo self._plan_repo = plan_repo self._cache = cache self._cache_ttl = cache_ttl def _cache_key(self, customer_id: str) -> str: return f"entitlements:{customer_id}"
[docs] async def check( self, customer_id: str, entitlement_key: str, ) -> CustomerEntitlement: """ Check a specific entitlement for a customer. Args: customer_id: Customer ID entitlement_key: Entitlement key to check Returns: Resolved CustomerEntitlement Raises: EntitlementNotFoundError: If entitlement key doesn't exist """ all_entitlements = await self.check_all(customer_id) for ent in all_entitlements: if ent.entitlement_key == entitlement_key: return ent # Check if entitlement definition exists entitlement_def = await self._entitlement_repo.get_by_key(entitlement_key) if entitlement_def is None: raise EntitlementNotFoundError(entitlement_key) # Return default value return CustomerEntitlement( customer_id=customer_id, entitlement_key=entitlement_key, current_value=entitlement_def.default_value or False, value_type=entitlement_def.value_type, source=EntitlementSource.DEFAULT, )
[docs] async def check_many( self, customer_id: str, entitlement_keys: list[str], ) -> dict[str, CustomerEntitlement]: """Check multiple entitlements at once.""" result: dict[str, CustomerEntitlement] = {} for key in entitlement_keys: result[key] = await self.check(customer_id, key) return result
[docs] async def check_all(self, customer_id: str) -> list[CustomerEntitlement]: """ Get all resolved entitlements for a customer. Results are cached for performance. """ log = logger.bind(customer_id=customer_id) # Check cache if self._cache: cached = await self._cache.get(self._cache_key(customer_id)) if cached is not None: log.debug("Entitlements cache hit") return [CustomerEntitlement.model_validate(e) for e in cached] log.debug("Resolving entitlements") # Collect entitlements from all sources entitlements: dict[str, CustomerEntitlement] = {} # 1. Load defaults from entitlement definitions all_definitions = await self._entitlement_repo.list_all() for definition in all_definitions: if definition.default_value is not None: entitlements[definition.key] = CustomerEntitlement( customer_id=customer_id, entitlement_key=definition.key, current_value=definition.default_value, value_type=definition.value_type, source=EntitlementSource.DEFAULT, ) # 2. Load from active subscriptionkores (overrides defaults) subscriptionkores = await self._subscriptionkore_repo.list_active_by_customer(customer_id) for subscriptionkore in subscriptionkores: plan = await self._plan_repo.get(subscriptionkore.plan_id) if plan is None: continue source = ( EntitlementSource.TRIAL if subscriptionkore.status == SubscriptionStatus.TRIALING else EntitlementSource.PLAN ) expires_at = subscriptionkore.trial_end if subscriptionkore.is_trialing() else None for plan_ent in plan.entitlements: existing = entitlements.get(plan_ent.entitlement_key) # Resolve value based on type new_value = plan_ent.value if existing and plan_ent.value is not None: new_value = self._resolve_value( existing.current_value, plan_ent.value, plan_ent.value_type, ) if new_value is not None: entitlements[plan_ent.entitlement_key] = CustomerEntitlement( customer_id=customer_id, entitlement_key=plan_ent.entitlement_key, current_value=new_value, value_type=plan_ent.value_type, source=source, expires_at=expires_at, subscriptionkore_id=subscriptionkore.id, plan_id=plan.id, ) # 3. Load overrides (highest priority) overrides = await self._override_repo.list_by_customer( customer_id=customer_id, include_expired=False, ) for override in overrides: if override.is_expired(): continue entitlements[override.entitlement_key] = CustomerEntitlement( customer_id=customer_id, entitlement_key=override.entitlement_key, current_value=override.value, value_type=override.value_type, source=EntitlementSource.OVERRIDE, expires_at=override.expires_at, ) result = list(entitlements.values()) # Cache result if self._cache: # Find shortest expiration for TTL min_ttl = self._cache_ttl for ent in result: if ent.expires_at: remaining = (ent.expires_at - datetime.utcnow()).total_seconds() if remaining > 0: min_ttl = min(min_ttl, int(remaining)) await self._cache.set( self._cache_key(customer_id), [e.model_dump() for e in result], ttl_seconds=min_ttl, ) return result
def _resolve_value( self, existing: bool | int | str, new: bool | int | str | None, value_type: EntitlementValueType, ) -> bool | int | str | None: """Resolve value when multiple sources provide the same entitlement.""" if new is None: return existing if value_type == EntitlementValueType.BOOLEAN: # Any true source wins return bool(existing) or bool(new) if value_type == EntitlementValueType.NUMERIC: # Maximum value wins return max(int(existing), int(new)) if value_type == EntitlementValueType.UNLIMITED: # Unlimited always wins return new # String: new value wins (higher priority source) return new
[docs] async def has_access( self, customer_id: str, entitlement_key: str, ) -> bool: """Check if customer has access to a feature.""" try: ent = await self.check(customer_id, entitlement_key) return ent.as_bool() except EntitlementNotFoundError: return False
[docs] async def get_limit( self, customer_id: str, entitlement_key: str, ) -> int | None: """ Get numeric limit for an entitlement. Returns None for unlimited entitlements. """ ent = await self.check(customer_id, entitlement_key) return ent.as_int()
[docs] async def check_within_limit( self, customer_id: str, entitlement_key: str, current_usage: int, requested: int = 1, ) -> tuple[bool, int | None]: """ Check if usage is within limit. Args: customer_id: Customer ID entitlement_key: Entitlement key current_usage: Current usage count requested: Additional usage being requested Returns: Tuple of (within_limit, remaining) remaining is None for unlimited """ limit = await self.get_limit(customer_id, entitlement_key) if limit is None: # Unlimited return True, None new_usage = current_usage + requested within_limit = new_usage <= limit remaining = max(0, limit - current_usage) return within_limit, remaining
[docs] async def enforce_limit( self, customer_id: str, entitlement_key: str, current_usage: int, requested: int = 1, ) -> int | None: """ Enforce usage limit, raising exception if exceeded. Args: customer_id: Customer ID entitlement_key: Entitlement key current_usage: Current usage count requested: Additional usage being requested Returns: Remaining quota (None for unlimited) Raises: UsageLimitExceededError: If limit would be exceeded """ limit = await self.get_limit(customer_id, entitlement_key) if limit is None: return None new_usage = current_usage + requested if new_usage > limit: raise UsageLimitExceededError( entitlement_key=entitlement_key, limit=limit, current_usage=current_usage, requested=requested, ) return limit - new_usage
# Override Management
[docs] async def grant_override( self, customer_id: str, entitlement_key: str, value: bool | int | str, expires_at: datetime | None = None, reason: str | None = None, created_by: str | None = None, ) -> EntitlementOverride: """ Grant an entitlement override to a customer. Args: customer_id: Customer ID entitlement_key: Entitlement key value: Override value expires_at: When override expires (None for permanent) reason: Reason for override created_by: Who created the override Returns: Created override """ log = logger.bind(customer_id=customer_id, entitlement_key=entitlement_key) log.info("Granting entitlement override") # Get entitlement definition entitlement_def = await self._entitlement_repo.get_by_key(entitlement_key) if entitlement_def is None: raise EntitlementNotFoundError(entitlement_key) # Delete existing override if any await self._override_repo.delete_by_customer_and_key(customer_id, entitlement_key) # Create new override override = EntitlementOverride( customer_id=customer_id, entitlement_key=entitlement_key, value=value, value_type=entitlement_def.value_type, reason=reason, expires_at=expires_at, created_by=created_by, ) override = await self._override_repo.save(override) # Invalidate cache await self.invalidate(customer_id) log.info("Entitlement override granted", override_id=override.id) return override
[docs] async def revoke_override( self, customer_id: str, entitlement_key: str, ) -> bool: """ Revoke an entitlement override. Returns: True if override was deleted """ log = logger.bind(customer_id=customer_id, entitlement_key=entitlement_key) log.info("Revoking entitlement override") deleted = await self._override_repo.delete_by_customer_and_key(customer_id, entitlement_key) if deleted: await self.invalidate(customer_id) log.info("Entitlement override revoked", deleted=deleted) return deleted
[docs] async def list_overrides( self, customer_id: str, include_expired: bool = False, ) -> list[EntitlementOverride]: """List all overrides for a customer.""" return await self._override_repo.list_by_customer( customer_id=customer_id, include_expired=include_expired, )
# Cache Management
[docs] async def invalidate(self, customer_id: str) -> None: """Invalidate entitlement cache for a customer.""" if self._cache: await self._cache.delete(self._cache_key(customer_id)) logger.debug("Entitlement cache invalidated", customer_id=customer_id)
[docs] async def invalidate_all(self) -> None: """Invalidate all entitlement caches.""" if self._cache: count = await self._cache.delete_pattern("entitlements:*") logger.info("All entitlement caches invalidated", count=count)