"""Stripe payment provider adapter."""
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING, Any
import httpx
import structlog
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from subscriptionkore.config import StripeConfig
from subscriptionkore.core.exceptions import (
ProviderAPIError,
ProviderAuthenticationError,
ProviderNetworkError,
ProviderRateLimitError,
)
from subscriptionkore.core.models import (
Customer,
Invoice,
InvoiceStatus,
Plan,
Product,
ProviderReference,
ProviderType,
Subscription,
SubscriptionStatus,
)
from subscriptionkore.core.models.invoice import InvoiceLineItem
from subscriptionkore.core.models.subscription import (
AppliedDiscount,
PauseBehavior,
PauseConfig,
)
from subscriptionkore.core.models.value_objects import (
BillingPeriod,
Currency,
DateRange,
Interval,
Money,
)
from subscriptionkore.ports.provider import (
ChangePlanRequest,
ChangePreview,
CheckoutRequest,
CheckoutSession,
CreateSubscriptionRequest,
DiscountRequest,
PaymentProviderPort,
PortalSession,
ProrationBehavior,
ProviderCapabilities,
ProviderWebhookEvent,
UpdateSubscriptionRequest,
)
# No type-checking-only imports are needed right now; the guard is an empty placeholder.
if TYPE_CHECKING:
    pass
# Module-level structlog logger.
logger = structlog.get_logger()
[docs]
class StripeAdapter(PaymentProviderPort):
    """Stripe payment provider implementation.

    Subclass of ``PaymentProviderPort`` that holds the Stripe API base URL and
    the lookup tables used to translate Stripe status strings into the domain
    status enums.
    """

    # Root URL of the Stripe REST API.
    BASE_URL = "https://api.stripe.com/v1"

    # Stripe subscription ``status`` string -> domain ``SubscriptionStatus``.
    STATUS_MAP = {
        "incomplete": SubscriptionStatus.INCOMPLETE,
        "incomplete_expired": SubscriptionStatus.INCOMPLETE_EXPIRED,
        "trialing": SubscriptionStatus.TRIALING,
        "active": SubscriptionStatus.ACTIVE,
        "past_due": SubscriptionStatus.PAST_DUE,
        "canceled": SubscriptionStatus.CANCELED,
        "unpaid": SubscriptionStatus.UNPAID,
        "paused": SubscriptionStatus.PAUSED,
    }

    # Stripe invoice ``status`` string -> domain ``InvoiceStatus``.
    INVOICE_STATUS_MAP = {
        "draft": InvoiceStatus.DRAFT,
        "open": InvoiceStatus.OPEN,
        "paid": InvoiceStatus.PAID,
        "void": InvoiceStatus.VOID,
        "uncollectible": InvoiceStatus.UNCOLLECTIBLE,
    }
def __init__(self, config: StripeConfig) -> None:
    """Remember the Stripe configuration and start without an HTTP client.

    Args:
        config: Stripe settings object kept on the instance as ``_config``.
    """
    # The client slot starts empty; nothing is opened here.
    self._client: httpx.AsyncClient | None = None
    self._config = config
async def _get_client(self) -> httpx.AsyncClient:
    """Return the cached ``httpx.AsyncClient``, building it on first use.

    The client is configured with the Stripe base URL, bearer authentication,
    the configured API version, form encoding and a 30 second timeout.
    """
    if self._client is not None:
        return self._client

    default_headers = {
        "Authorization": f"Bearer {self._config.api_key}",
        "Stripe-Version": self._config.api_version,
        "Content-Type": "application/x-www-form-urlencoded",
    }
    self._client = httpx.AsyncClient(
        base_url=self.BASE_URL,
        headers=default_headers,
        timeout=30.0,
    )
    return self._client
[docs]
async def close(self) -> None:
    """Close the HTTP client if one exists and clear the cached reference."""
    client = self._client
    if not client:
        # Nothing was ever opened (or it was already closed).
        return
    await client.aclose()
    self._client = None
@property
def provider_type(self) -> ProviderType:
    """Provider identifier for this adapter: always ``ProviderType.STRIPE``."""
    stripe_provider = ProviderType.STRIPE
    return stripe_provider
@property
def capabilities(self) -> ProviderCapabilities:
    """Feature flags reported by this adapter; every listed flag is ``True``."""
    enabled_flags = (
        "supports_pausing",
        "supports_trials",
        "supports_quantity",
        "supports_immediate_cancel",
        "supports_proration",
        "supports_coupons",
        "supports_metered_billing",
        "supports_customer_portal",
        "supports_checkout_sessions",
    )
    return ProviderCapabilities(**dict.fromkeys(enabled_flags, True))
def _handle_error(self, response: httpx.Response) -> None:
    """Raise the matching provider exception for an unsuccessful response.

    Responses with a status below 400 return ``None`` without raising.

    Args:
        response: HTTP response received from Stripe.

    Raises:
        ProviderRateLimitError: On HTTP 429. ``retry_after`` carries the
            ``Retry-After`` header as an int, or ``None`` when the header is
            absent or not an integer.
        ProviderAuthenticationError: On HTTP 401.
        ProviderAPIError: On any other status >= 400. Message and code come
            from the JSON ``error`` object when it can be read; otherwise a
            generic message containing the status code is used.
    """
    status = response.status_code
    if status < 400:
        return

    if status == 429:
        raw_retry_after = response.headers.get("Retry-After")
        try:
            retry_after = int(raw_retry_after) if raw_retry_after else None
        except ValueError:
            # A non-integer Retry-After (e.g. an HTTP-date) must not mask the
            # rate-limit error with a ValueError.
            retry_after = None
        raise ProviderRateLimitError(provider="stripe", retry_after=retry_after)

    if status == 401:
        raise ProviderAuthenticationError(provider="stripe")

    try:
        error_data = response.json().get("error", {})
        message = error_data.get("message", "Unknown Stripe error")
        provider_message = error_data.get("message")
        provider_code = error_data.get("code")
    except Exception as e:
        # Body is not JSON, or not shaped like a Stripe error envelope.
        raise ProviderAPIError(
            message=f"Stripe API error: {status}",
            provider="stripe",
            status_code=status,
        ) from e

    raise ProviderAPIError(
        message=message,
        provider="stripe",
        status_code=status,
        provider_message=provider_message,
        provider_code=provider_code,
    )
@retry(
    retry=retry_if_exception_type((ProviderRateLimitError, ProviderNetworkError)),
    stop=stop_after_attempt(3),
    wait=wait_exponential_jitter(initial=1, max=60),
)
async def _request(
    self,
    method: str,
    endpoint: str,
    data: dict[str, Any] | None = None,
    params: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Make an authenticated request to the Stripe API.

    The call is attempted up to three times, with exponential backoff and
    jitter, when it fails with ``ProviderRateLimitError`` or
    ``ProviderNetworkError``.

    Args:
        method: HTTP method name.
        endpoint: Path passed to the client as the request URL.
        data: Form fields for the request body, if any.
        params: Query-string parameters, if any.

    Returns:
        The decoded JSON response body.

    Raises:
        ProviderNetworkError: On an httpx network error or timeout.
        ProviderRateLimitError, ProviderAuthenticationError, ProviderAPIError:
            As raised by ``_handle_error`` for error statuses.
    """
    # NOTE(review): the decorator does not set ``reraise=True``, so exhausted
    # retries presumably surface as ``tenacity.RetryError`` - confirm callers
    # expect that. Retried POSTs also carry no idempotency key.
    client = await self._get_client()
    try:
        response = await client.request(
            method=method,
            url=endpoint,
            data=data,
            params=params,
        )
    except (httpx.NetworkError, httpx.TimeoutException) as e:
        # Timeouts are not NetworkError subclasses in httpx; wrap both so they
        # are reported as provider errors and take part in the retry policy.
        raise ProviderNetworkError(provider="stripe", original_error=e) from e
    self._handle_error(response)
    return response.json()
# Customer Operations
[docs]
async def create_customer(self, customer: Customer) -> ProviderReference:
    """Create a Stripe customer from a domain customer.

    Sends the email, the internal and external IDs (as metadata), and the
    name and billing-address parts when they are set.

    Returns:
        A Stripe ``ProviderReference`` holding the new customer ID and the
        email echoed back in the response.
    """
    payload: dict[str, Any] = {
        "email": customer.email,
        "metadata[subscriptionkore_id]": customer.id,
        "metadata[external_id]": customer.external_id,
    }
    if customer.name:
        payload["name"] = customer.name

    billing = customer.billing_address
    if billing:
        # Only non-empty address parts are forwarded.
        for part in ("line1", "line2", "city", "state", "postal_code", "country"):
            part_value = getattr(billing, part)
            if part_value:
                payload[f"address[{part}]"] = part_value

    created = await self._request("POST", "/customers", data=payload)
    return ProviderReference(
        provider=ProviderType.STRIPE,
        external_id=created["id"],
        metadata={"email": created.get("email")},
    )
[docs]
async def update_customer(self, customer: Customer) -> None:
    """Push the customer's email (and name, when set) to Stripe.

    Raises:
        ProviderAPIError: With status 400 when the customer has no Stripe
            provider reference.
    """
    stripe_ref = customer.get_provider_ref("stripe")
    if stripe_ref is None:
        raise ProviderAPIError(
            message="Customer has no Stripe reference",
            provider="stripe",
            status_code=400,
        )

    fields: dict[str, Any] = {"email": customer.email}
    if customer.name:
        fields.update(name=customer.name)

    endpoint = f"/customers/{stripe_ref.external_id}"
    await self._request("POST", endpoint, data=fields)
[docs]
async def delete_customer(self, provider_ref: ProviderReference) -> None:
    """Delete the Stripe customer identified by ``provider_ref.external_id``."""
    endpoint = f"/customers/{provider_ref.external_id}"
    await self._request("DELETE", endpoint)
[docs]
async def get_customer(self, provider_ref: ProviderReference) -> Customer:
    """Fetch a Stripe customer and map it to the domain ``Customer``.

    Internal and external IDs are read back from the Stripe metadata and
    default to empty strings when missing.
    """
    payload = await self._request("GET", f"/customers/{provider_ref.external_id}")
    from subscriptionkore.core.models.customer import Address

    billing_address = None
    raw_address = payload.get("address")
    if raw_address:
        address_parts = ("line1", "line2", "city", "state", "postal_code", "country")
        billing_address = Address(**{part: raw_address.get(part) for part in address_parts})

    internal_id = payload.get("metadata", {}).get("subscriptionkore_id", "")
    external_id = payload.get("metadata", {}).get("external_id", "")
    return Customer(
        id=internal_id,
        external_id=external_id,
        email=payload.get("email", ""),
        name=payload.get("name"),
        billing_address=billing_address,
        provider_refs=[provider_ref],
        created_at=datetime.fromtimestamp(payload["created"]),
    )
# Subscription Operations
[docs]
async def create_subscriptionkore(
    self,
    request: CreateSubscriptionRequest,
    customer_provider_ref: ProviderReference,
    plan_provider_ref: ProviderReference,
) -> Subscription:
    """Create a Stripe subscription for a customer on a price.

    Args:
        request: Quantity, optional trial days, coupon, default payment
            method and extra metadata for the new subscription.
        customer_provider_ref: Reference whose ``external_id`` is the Stripe
            customer ID.
        plan_provider_ref: Reference whose ``external_id`` is the Stripe
            price ID.

    Returns:
        The created subscription as mapped by ``_map_subscriptionkore``.
    """
    data: dict[str, Any] = {
        "customer": customer_provider_ref.external_id,
        "items[0][price]": plan_provider_ref.external_id,
        "items[0][quantity]": str(request.quantity),
        "metadata[subscriptionkore_customer_id]": request.customer_id,
        "metadata[subscriptionkore_plan_id]": request.plan_id,
    }
    if request.trial_period_days:
        data["trial_period_days"] = str(request.trial_period_days)
    if request.coupon_code:
        data["coupon"] = request.coupon_code
    if request.payment_method_id:
        data["default_payment_method"] = request.payment_method_id
    # Caller metadata is written last, so it can override the internal keys.
    for key, value in (request.metadata or {}).items():
        data[f"metadata[{key}]"] = str(value)

    # Stripe's subscription resource lives at /subscriptions.
    result = await self._request("POST", "/subscriptions", data=data)
    return self._map_subscriptionkore(result, request.customer_id, request.plan_id)
[docs]
async def update_subscriptionkore(
    self,
    request: UpdateSubscriptionRequest,
    subscriptionkore_provider_ref: ProviderReference,
) -> Subscription:
    """Update quantity, cancel-at-period-end and/or metadata of a subscription.

    Only the fields set on ``request`` are sent. Changing the quantity needs
    an extra GET first, because the update must name the subscription item.

    Returns:
        The updated subscription as mapped by ``_map_subscriptionkore``.
    """
    # Stripe's subscription resource lives at /subscriptions.
    endpoint = f"/subscriptions/{subscriptionkore_provider_ref.external_id}"
    data: dict[str, Any] = {}

    if request.quantity is not None:
        current = await self._request("GET", endpoint)
        items = current.get("items", {}).get("data")
        if items:
            # Only the first subscription item is updated.
            data["items[0][id]"] = items[0]["id"]
            data["items[0][quantity]"] = str(request.quantity)

    if request.cancel_at_period_end is not None:
        # Form encoding expects lowercase "true"/"false".
        data["cancel_at_period_end"] = str(request.cancel_at_period_end).lower()

    if request.metadata:
        for key, value in request.metadata.items():
            data[f"metadata[{key}]"] = str(value)

    result = await self._request("POST", endpoint, data=data)
    return self._map_subscriptionkore(result, "", "")
[docs]
async def cancel_subscriptionkore(
    self,
    subscriptionkore_provider_ref: ProviderReference,
    immediate: bool = False,
) -> Subscription:
    """Cancel a Stripe subscription now or at the end of the current period.

    Args:
        subscriptionkore_provider_ref: Reference whose ``external_id`` is the
            Stripe subscription ID.
        immediate: When true the subscription is deleted right away;
            otherwise ``cancel_at_period_end`` is set.

    Returns:
        The resulting subscription as mapped by ``_map_subscriptionkore``.
    """
    # Stripe's subscription resource lives at /subscriptions.
    endpoint = f"/subscriptions/{subscriptionkore_provider_ref.external_id}"
    if immediate:
        result = await self._request("DELETE", endpoint)
    else:
        result = await self._request(
            "POST",
            endpoint,
            data={"cancel_at_period_end": "true"},
        )
    return self._map_subscriptionkore(result, "", "")
[docs]
async def pause_subscriptionkore(
    self,
    subscriptionkore_provider_ref: ProviderReference,
    resumes_at: datetime | None = None,
) -> Subscription:
    """Pause payment collection on a subscription using the ``void`` behavior.

    Args:
        subscriptionkore_provider_ref: Reference whose ``external_id`` is the
            Stripe subscription ID.
        resumes_at: Optional moment at which collection resumes; sent as a
            whole-second Unix timestamp.

    Returns:
        The updated subscription as mapped by ``_map_subscriptionkore``.
    """
    data: dict[str, Any] = {"pause_collection[behavior]": "void"}
    if resumes_at:
        data["pause_collection[resumes_at]"] = str(int(resumes_at.timestamp()))
    # Stripe's subscription resource lives at /subscriptions.
    result = await self._request(
        "POST",
        f"/subscriptions/{subscriptionkore_provider_ref.external_id}",
        data=data,
    )
    return self._map_subscriptionkore(result, "", "")
[docs]
async def resume_subscriptionkore(
    self,
    subscriptionkore_provider_ref: ProviderReference,
) -> Subscription:
    """Resume a paused subscription by clearing its ``pause_collection``.

    Returns:
        The updated subscription as mapped by ``_map_subscriptionkore``.
    """
    # Stripe's subscription resource lives at /subscriptions; posting an empty
    # value is how the pause_collection field is unset.
    result = await self._request(
        "POST",
        f"/subscriptions/{subscriptionkore_provider_ref.external_id}",
        data={"pause_collection": ""},
    )
    return self._map_subscriptionkore(result, "", "")
[docs]
async def get_subscriptionkore(
    self,
    provider_ref: ProviderReference,
) -> Subscription:
    """Fetch a Stripe subscription and map it to the domain model.

    Customer and plan IDs are not supplied here, so the mapper resolves them
    from the subscription metadata.
    """
    # Stripe's subscription resource lives at /subscriptions.
    result = await self._request("GET", f"/subscriptions/{provider_ref.external_id}")
    return self._map_subscriptionkore(result, "", "")
# Plan Change Operations
[docs]
async def change_plan(
    self,
    request: ChangePlanRequest,
    subscriptionkore_provider_ref: ProviderReference,
    new_plan_provider_ref: ProviderReference,
) -> Subscription:
    """Move a subscription's first item to a different Stripe price.

    Args:
        request: Carries the new internal plan ID and the proration behavior.
        subscriptionkore_provider_ref: Reference whose ``external_id`` is the
            Stripe subscription ID.
        new_plan_provider_ref: Reference whose ``external_id`` is the new
            Stripe price ID.

    Returns:
        The updated subscription as mapped by ``_map_subscriptionkore``, with
        ``request.new_plan_id`` as its plan ID.
    """
    # Stripe's subscription resource lives at /subscriptions.
    endpoint = f"/subscriptions/{subscriptionkore_provider_ref.external_id}"

    # The update must name the existing item, so fetch the subscription first.
    sub = await self._request("GET", endpoint)
    item_id = sub["items"]["data"][0]["id"]

    proration_map = {
        ProrationBehavior.CREATE_PRORATIONS: "create_prorations",
        ProrationBehavior.NONE: "none",
        ProrationBehavior.ALWAYS_INVOICE: "always_invoice",
    }
    data: dict[str, Any] = {
        "items[0][id]": item_id,
        "items[0][price]": new_plan_provider_ref.external_id,
        # Unmapped behaviors fall back to creating prorations.
        "proration_behavior": proration_map.get(
            request.proration_behavior, "create_prorations"
        ),
    }
    result = await self._request("POST", endpoint, data=data)
    return self._map_subscriptionkore(result, "", request.new_plan_id)
[docs]
async def preview_plan_change(
    self,
    request: ChangePlanRequest,
    subscriptionkore_provider_ref: ProviderReference,
    new_plan_provider_ref: ProviderReference,
) -> ChangePreview:
    """Preview the cost of switching a subscription to another price.

    Asks Stripe for the upcoming invoice as if the first subscription item
    were moved to the new price, then totals the proration lines.

    Args:
        request: Carries the proration behavior to preview with.
        subscriptionkore_provider_ref: Reference whose ``external_id`` is the
            Stripe subscription ID.
        new_plan_provider_ref: Reference whose ``external_id`` is the new
            Stripe price ID.

    Returns:
        A ``ChangePreview`` with the amount due, invoice total, summed
        proration charges, summed proration credits and next billing date.
    """
    subscription_id = subscriptionkore_provider_ref.external_id

    # The preview must name the existing item, so fetch the subscription first.
    sub = await self._request("GET", f"/subscriptions/{subscription_id}")
    item_id = sub["items"]["data"][0]["id"]

    # Preview with the same proration behavior change_plan would apply.
    proration_map = {
        ProrationBehavior.CREATE_PRORATIONS: "create_prorations",
        ProrationBehavior.NONE: "none",
        ProrationBehavior.ALWAYS_INVOICE: "always_invoice",
    }
    result = await self._request(
        "GET",
        "/invoices/upcoming",
        params={
            "subscription": subscription_id,
            "subscription_items[0][id]": item_id,
            "subscription_items[0][price]": new_plan_provider_ref.external_id,
            "subscription_proration_behavior": proration_map.get(
                request.proration_behavior, "create_prorations"
            ),
        },
    )
    currency = Currency(result.get("currency", "usd").upper())

    # Split proration lines into charges (positive) and credits (negative).
    # NOTE(review): dividing by 100 assumes a two-decimal currency - confirm
    # how zero-decimal currencies should be handled.
    proration_amount = Decimal("0")
    credit_amount = Decimal("0")
    for line in result.get("lines", {}).get("data", []):
        if not line.get("proration"):
            continue
        amount = Decimal(line["amount"]) / 100
        if amount < 0:
            credit_amount += abs(amount)
        else:
            proration_amount += amount

    return ChangePreview(
        immediate_charge=Money.from_cents(result.get("amount_due", 0), currency),
        next_invoice_amount=Money.from_cents(result.get("total", 0), currency),
        proration_amount=Money(amount=proration_amount, currency=currency),
        credit_amount=Money(amount=credit_amount, currency=currency),
        # The field may be present but null; fall back to 0 as for a missing key.
        next_billing_date=datetime.fromtimestamp(result.get("next_payment_attempt") or 0),
    )
# Discount Operations
[docs]
async def apply_discount(
    self,
    subscriptionkore_provider_ref: ProviderReference,
    discount: DiscountRequest,
) -> Subscription:
    """Attach a coupon or promotion code to a subscription.

    The coupon code wins when both are set. When neither is set the update
    is still sent, with an empty body.

    Returns:
        The updated subscription as mapped by ``_map_subscriptionkore``.
    """
    data: dict[str, Any] = {}
    if discount.coupon_code:
        data["coupon"] = discount.coupon_code
    elif discount.promotion_code:
        data["promotion_code"] = discount.promotion_code
    # Stripe's subscription resource lives at /subscriptions.
    result = await self._request(
        "POST",
        f"/subscriptions/{subscriptionkore_provider_ref.external_id}",
        data=data,
    )
    return self._map_subscriptionkore(result, "", "")
[docs]
async def remove_discount(
    self,
    subscriptionkore_provider_ref: ProviderReference,
) -> Subscription:
    """Remove the discount from a subscription and return its fresh state.

    The subscription is re-fetched after the delete, and that second response
    is what gets mapped.

    Returns:
        The refreshed subscription as mapped by ``_map_subscriptionkore``.
    """
    # Stripe's subscription resource lives at /subscriptions.
    endpoint = f"/subscriptions/{subscriptionkore_provider_ref.external_id}"
    await self._request("DELETE", f"{endpoint}/discount")
    result = await self._request("GET", endpoint)
    return self._map_subscriptionkore(result, "", "")
# Billing Operations
[docs]
async def get_invoice(self, provider_ref: ProviderReference) -> Invoice:
    """Fetch one Stripe invoice by ID and map it with ``_map_invoice``."""
    endpoint = f"/invoices/{provider_ref.external_id}"
    raw_invoice = await self._request("GET", endpoint)
    return self._map_invoice(raw_invoice)
[docs]
async def list_invoices(
    self,
    customer_provider_ref: ProviderReference,
    limit: int = 10,
    starting_after: str | None = None,
) -> list[Invoice]:
    """List one page of a customer's invoices, mapped to domain models.

    Args:
        customer_provider_ref: Reference whose ``external_id`` is the Stripe
            customer ID.
        limit: Page size sent to Stripe.
        starting_after: Optional cursor; sent only when truthy.
    """
    query: dict[str, Any] = {
        "customer": customer_provider_ref.external_id,
        "limit": limit,
    }
    if starting_after:
        query["starting_after"] = starting_after

    page = await self._request("GET", "/invoices", params=query)
    invoices = []
    for raw_invoice in page.get("data", []):
        invoices.append(self._map_invoice(raw_invoice))
    return invoices
[docs]
async def get_upcoming_invoice(
    self,
    subscriptionkore_provider_ref: ProviderReference,
) -> Invoice | None:
    """Fetch the upcoming invoice for a subscription, if there is one.

    Returns:
        The mapped invoice, or ``None`` when Stripe answers with HTTP 404.

    Raises:
        ProviderAPIError: For any API error other than 404.
    """
    try:
        result = await self._request(
            "GET",
            "/invoices/upcoming",
            # Stripe's filter parameter is named "subscription".
            params={"subscription": subscriptionkore_provider_ref.external_id},
        )
    except ProviderAPIError as e:
        if e.status_code == 404:
            return None
        raise
    return self._map_invoice(result)
# Checkout / Portal
[docs]
async def create_checkout_session(
    self,
    request: CheckoutRequest,
    plan_provider_ref: ProviderReference,
    customer_provider_ref: ProviderReference | None = None,
) -> CheckoutSession:
    """Create a Stripe Checkout session that starts a subscription.

    Args:
        request: Redirect URLs, quantity, optional customer email, trial
            days, promotion-code flag and metadata.
        plan_provider_ref: Reference whose ``external_id`` is the Stripe
            price ID.
        customer_provider_ref: Existing Stripe customer, if any. When absent,
            ``request.customer_email`` is sent instead if set.

    Returns:
        A ``CheckoutSession`` with the session ID, URL and expiry time.
    """
    data: dict[str, Any] = {
        # Stripe's Checkout mode for recurring prices is "subscription".
        "mode": "subscription",
        "success_url": request.success_url,
        "cancel_url": request.cancel_url,
        "line_items[0][price]": plan_provider_ref.external_id,
        "line_items[0][quantity]": str(request.quantity),
    }
    if customer_provider_ref:
        data["customer"] = customer_provider_ref.external_id
    elif request.customer_email:
        data["customer_email"] = request.customer_email
    if request.trial_period_days:
        data["subscription_data[trial_period_days]"] = str(request.trial_period_days)
    if request.allow_promotion_codes:
        data["allow_promotion_codes"] = "true"
    for key, value in request.metadata.items():
        data[f"metadata[{key}]"] = str(value)

    result = await self._request("POST", "/checkout/sessions", data=data)
    return CheckoutSession(
        id=result["id"],
        url=result["url"],
        expires_at=datetime.fromtimestamp(result["expires_at"]),
    )
[docs]
async def create_portal_session(
    self,
    customer_provider_ref: ProviderReference,
    return_url: str,
) -> PortalSession:
    """Create a Stripe billing-portal session for a customer.

    Args:
        customer_provider_ref: Reference whose ``external_id`` is the Stripe
            customer ID.
        return_url: URL sent to Stripe as the session's ``return_url``.

    Returns:
        A ``PortalSession`` with the session ID and URL from the response.
    """
    form = {
        "customer": customer_provider_ref.external_id,
        "return_url": return_url,
    }
    session = await self._request("POST", "/billing_portal/sessions", data=form)
    session_id, session_url = session["id"], session["url"]
    return PortalSession(id=session_id, url=session_url)
# Webhook Handling
[docs]
async def verify_webhook(
    self,
    payload: bytes,
    headers: dict[str, str],
) -> bool:
    """Verify the Stripe signature of a webhook delivery.

    The ``Stripe-Signature`` header is parsed for a timestamp (``t``) and any
    number of ``v1`` signatures. The delivery is accepted when the timestamp
    is within five minutes of now and at least one ``v1`` signature equals
    the HMAC-SHA256 of ``"<t>.<raw payload>"`` keyed with the configured
    webhook secret.

    Args:
        payload: Raw request body, exactly as received.
        headers: Request headers; the signature header is matched
            case-insensitively.

    Returns:
        ``True`` if the signature is valid and fresh, ``False`` otherwise.
    """
    import hashlib
    import hmac
    import time

    # HTTP header names are case-insensitive, so do not rely on exact casing.
    signature_header = next(
        (value for name, value in headers.items() if name.lower() == "stripe-signature"),
        "",
    )
    if not signature_header:
        return False

    # Collect every v1 signature: a header can carry several, and folding the
    # pairs into a dict would keep only the last one.
    timestamp = ""
    signatures: list[str] = []
    for item in signature_header.split(","):
        key, sep, value = item.strip().partition("=")
        if not sep:
            continue
        if key == "t":
            timestamp = value
        elif key == "v1":
            signatures.append(value)
    if not timestamp or not signatures:
        return False

    # Reject deliveries whose timestamp is more than 5 minutes from now.
    try:
        ts = int(timestamp)
    except ValueError:
        return False
    if abs(time.time() - ts) > 300:
        return False

    # The signed message is "<timestamp>.<payload>" with no extra whitespace.
    # Work on bytes so the body is hashed exactly as it arrived.
    signed_payload = timestamp.encode("utf-8") + b"." + payload
    expected_sig = hmac.new(
        self._config.webhook_secret.encode("utf-8"),
        signed_payload,
        hashlib.sha256,
    ).hexdigest().encode("ascii")

    # Compare as bytes: compare_digest raises on non-ASCII str input.
    return any(
        hmac.compare_digest(expected_sig, sig.encode("utf-8")) for sig in signatures
    )
[docs]
async def parse_webhook(
    self,
    payload: bytes,
    headers: dict[str, str],
) -> ProviderWebhookEvent:
    """Decode a Stripe webhook body into a ``ProviderWebhookEvent``.

    ``headers`` is accepted but not used. The event's ``data`` defaults to an
    empty dict when missing, and the raw payload is kept on the result.
    """
    import json

    event = json.loads(payload)
    event_id = event["id"]
    event_type = event["type"]
    occurred_at = datetime.fromtimestamp(event["created"])
    event_data = event.get("data", {})
    return ProviderWebhookEvent(
        provider=ProviderType.STRIPE,
        event_id=event_id,
        event_type=event_type,
        occurred_at=occurred_at,
        data=event_data,
        raw_payload=payload,
    )
# Sync Operations
[docs]
async def sync_products(self) -> list[Product]:
    """Fetch all active Stripe products as domain ``Product`` objects.

    Pages through the list endpoint 100 items at a time, following
    ``has_more`` with ``starting_after``, so catalogs larger than one page
    are not silently truncated.

    Returns:
        Products with an empty ``id`` (left for the repository to assign) and
        a Stripe provider reference each.
    """
    products: list[Product] = []
    params: dict[str, Any] = {"limit": 100, "active": "true"}
    while True:
        page = await self._request("GET", "/products", params=params)
        rows = page.get("data", [])
        for prod in rows:
            products.append(
                Product(
                    id="",  # Will be assigned by repository
                    name=prod["name"],
                    description=prod.get("description"),
                    provider_refs=[
                        ProviderReference(
                            provider=ProviderType.STRIPE,
                            external_id=prod["id"],
                        )
                    ],
                    active=prod["active"],
                    metadata=prod.get("metadata", {}),
                    created_at=datetime.fromtimestamp(prod["created"]),
                )
            )
        if not rows or not page.get("has_more"):
            break
        # Cursor pagination: continue after the last object of this page.
        params = {**params, "starting_after": rows[-1]["id"]}
    return products
[docs]
async def sync_plans(self, product_provider_ref: ProviderReference) -> list[Plan]:
    """Fetch all active recurring Stripe prices of a product as ``Plan`` objects.

    Pages through the list endpoint 100 items at a time, following
    ``has_more`` with ``starting_after``. Prices whose ``type`` is not
    ``"recurring"`` are skipped.

    Args:
        product_provider_ref: Reference whose ``external_id`` is the Stripe
            product ID.

    Returns:
        Plans with empty ``id`` and ``product_id`` (left for the caller to
        assign/resolve) and a Stripe provider reference each.
    """
    plans: list[Plan] = []
    params: dict[str, Any] = {
        "product": product_provider_ref.external_id,
        "limit": 100,
        "active": "true",
    }
    while True:
        page = await self._request("GET", "/prices", params=params)
        rows = page.get("data", [])
        for price in rows:
            if price.get("type") != "recurring":
                continue
            recurring = price.get("recurring") or {}
            plans.append(
                Plan(
                    id="",  # Will be assigned by repository
                    product_id="",  # Will be resolved
                    name=price.get("nickname") or price["id"],
                    provider_refs=[
                        ProviderReference(
                            provider=ProviderType.STRIPE,
                            external_id=price["id"],
                        )
                    ],
                    price=Money.from_cents(
                        # unit_amount can be present but null; treat as 0.
                        price.get("unit_amount") or 0,
                        Currency((price.get("currency") or "usd").upper()),
                    ),
                    billing_period=BillingPeriod(
                        interval=Interval(recurring.get("interval", "month")),
                        interval_count=recurring.get("interval_count", 1),
                    ),
                    trial_period_days=recurring.get("trial_period_days"),
                    active=price["active"],
                    metadata=price.get("metadata", {}),
                    created_at=datetime.fromtimestamp(price["created"]),
                )
            )
        if not rows or not page.get("has_more"):
            break
        # Cursor pagination: continue after the last object of this page.
        params = {**params, "starting_after": rows[-1]["id"]}
    return plans
# Mapping Helpers
def _map_subscriptionkore(
    self,
    data: dict[str, Any],
    customer_id: str,
    plan_id: str,
) -> Subscription:
    """Map a Stripe subscription object to the domain ``Subscription``.

    Args:
        data: Stripe subscription JSON.
        customer_id: Internal customer ID; when empty, the
            ``subscriptionkore_customer_id`` metadata value is used.
        plan_id: Internal plan ID; when empty, the
            ``subscriptionkore_plan_id`` metadata value is used.

    Returns:
        A ``Subscription`` with an empty ``id`` (left for the caller to
        assign) and a Stripe provider reference.
    """
    # Unknown Stripe statuses are treated as active.
    status = self.STATUS_MAP.get(data["status"], SubscriptionStatus.ACTIVE)

    # A pause_collection object overrides the status with PAUSED.
    pause_config = None
    pause_data = data.get("pause_collection")
    if pause_data:
        resumes_at = None
        if pause_data.get("resumes_at"):
            resumes_at = datetime.fromtimestamp(pause_data["resumes_at"])
        # NOTE(review): the behavior reported by Stripe is ignored and VOID is
        # always recorded - confirm this is intended.
        pause_config = PauseConfig(
            resumes_at=resumes_at,
            behavior=PauseBehavior.VOID,
        )
        status = SubscriptionStatus.PAUSED

    # Handle discount
    discount = None
    discount_data = data.get("discount")
    if discount_data:
        coupon = discount_data.get("coupon") or {}
        amount_off = None
        if coupon.get("amount_off"):
            # Use the coupon's own currency, then the subscription's, rather
            # than assuming USD.
            coupon_currency = coupon.get("currency") or data.get("currency") or "usd"
            amount_off = Money.from_cents(
                coupon["amount_off"], Currency(coupon_currency.upper())
            )
        discount = AppliedDiscount(
            discount_id=discount_data["id"],
            coupon_code=coupon.get("id"),
            amount_off=amount_off,
            percent_off=Decimal(str(coupon["percent_off"]))
            if coupon.get("percent_off")
            else None,
            valid_until=datetime.fromtimestamp(discount_data["end"])
            if discount_data.get("end")
            else None,
        )

    # Quantity comes from the first subscription item; default is 1.
    quantity = 1
    items = data.get("items", {}).get("data", [])
    if items:
        quantity = items[0].get("quantity", 1)

    # Get trial end
    trial_end = None
    if data.get("trial_end"):
        trial_end = datetime.fromtimestamp(data["trial_end"])

    # Use metadata for internal IDs if not provided
    metadata = data.get("metadata", {})
    resolved_customer_id = customer_id or metadata.get("subscriptionkore_customer_id", "")
    resolved_plan_id = plan_id or metadata.get("subscriptionkore_plan_id", "")

    return Subscription(
        id="",  # Will be assigned
        customer_id=resolved_customer_id,
        plan_id=resolved_plan_id,
        provider_ref=ProviderReference(
            provider=ProviderType.STRIPE,
            external_id=data["id"],
            metadata={"customer": data.get("customer")},
        ),
        status=status,
        current_period=DateRange(
            start=datetime.fromtimestamp(data["current_period_start"]),
            end=datetime.fromtimestamp(data["current_period_end"]),
        ),
        trial_end=trial_end,
        cancel_at_period_end=data.get("cancel_at_period_end", False),
        canceled_at=datetime.fromtimestamp(data["canceled_at"])
        if data.get("canceled_at")
        else None,
        ended_at=datetime.fromtimestamp(data["ended_at"]) if data.get("ended_at") else None,
        pause_collection=pause_config,
        discount=discount,
        quantity=quantity,
        metadata=metadata,
        created_at=datetime.fromtimestamp(data["created"]),
    )
def _map_invoice(self, data: dict[str, Any]) -> Invoice:
    """Map a Stripe invoice object to the domain ``Invoice``.

    Args:
        data: Stripe invoice JSON.

    Returns:
        An ``Invoice`` with empty ``id`` and ``customer_id`` (left for the
        caller to assign/resolve), its line items, and a Stripe provider
        reference.
    """
    currency = Currency(data.get("currency", "usd").upper())
    # Missing or unknown statuses are treated as draft.
    status = self.INVOICE_STATUS_MAP.get(data.get("status", "draft"), InvoiceStatus.DRAFT)

    # Map line items
    line_items = []
    for line in data.get("lines", {}).get("data", []):
        line_period = None
        if line.get("period"):
            line_period = DateRange(
                start=datetime.fromtimestamp(line["period"]["start"]),
                end=datetime.fromtimestamp(line["period"]["end"]),
            )
        line_items.append(
            InvoiceLineItem(
                id=line["id"],
                description=line.get("description", ""),
                quantity=line.get("quantity", 1),
                unit_amount=Money.from_cents(line.get("unit_amount", 0), currency),
                amount=Money.from_cents(line.get("amount", 0), currency),
                period=line_period,
                proration=line.get("proration", False),
            )
        )

    # Invoice period comes from the invoice-level dates only.
    period = None
    if data.get("period_start") and data.get("period_end"):
        period = DateRange(
            start=datetime.fromtimestamp(data["period_start"]),
            end=datetime.fromtimestamp(data["period_end"]),
        )

    # An invoice can carry several discounts; add them all up.
    discount_cents = sum(
        entry.get("amount", 0) for entry in data.get("total_discount_amounts") or []
    )

    # status_transitions may be missing or null.
    paid_at_ts = (data.get("status_transitions") or {}).get("paid_at")

    return Invoice(
        id="",  # Will be assigned
        customer_id="",  # Will be resolved
        # Stripe names this field "subscription".
        subscriptionkore_id=data.get("subscription"),
        provider_ref=ProviderReference(
            provider=ProviderType.STRIPE,
            external_id=data.get("id", ""),
        ),
        status=status,
        subtotal=Money.from_cents(data.get("subtotal", 0), currency),
        tax=Money.from_cents(data.get("tax", 0) or 0, currency),
        discount_amount=Money.from_cents(discount_cents, currency),
        total=Money.from_cents(data.get("total", 0), currency),
        amount_paid=Money.from_cents(data.get("amount_paid", 0), currency),
        amount_due=Money.from_cents(data.get("amount_due", 0), currency),
        currency=currency,
        line_items=line_items,
        period=period,
        due_date=datetime.fromtimestamp(data["due_date"]) if data.get("due_date") else None,
        paid_at=datetime.fromtimestamp(paid_at_ts) if paid_at_ts else None,
        invoice_pdf_url=data.get("invoice_pdf"),
        hosted_invoice_url=data.get("hosted_invoice_url"),
        metadata=data.get("metadata", {}),
        created_at=datetime.fromtimestamp(data["created"]),
    )