Coverage for src / documint_mcp / db.py: 0%
347 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 22:30 -0400
1"""SQLAlchemy persistence layer for the Documint control plane."""
3from __future__ import annotations
5import atexit
6from collections.abc import Iterator
7from contextlib import contextmanager
8from datetime import UTC, datetime
10from sqlalchemy import (
11 JSON,
12 DateTime,
13 Float,
14 ForeignKey,
15 Index,
16 Integer,
17 String,
18 Text,
19 UniqueConstraint,
20 create_engine,
21)
22from sqlalchemy.engine import Engine
23from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker
25from .config import settings
28def utcnow() -> datetime:
29 return datetime.now(tz=UTC)
class Base(DeclarativeBase):
    """Declarative base shared by every ``*Record`` ORM model in this module.

    ``Base.metadata`` therefore holds the complete table set that ``init_db``
    passes to ``create_all``. Under SQLAlchemy 2.0 typed mappings, a column
    annotated with a non-Optional ``Mapped[...]`` is NOT NULL unless
    ``nullable`` is passed explicitly; ``Mapped[X | None]`` makes it nullable.
    """
class UserRecord(Base):
    """Row in ``users``.

    Referenced by ``workspace_members.user_id`` and ``api_tokens.user_id``
    (both ON DELETE CASCADE), so deleting a user removes their memberships
    and tokens at the database level.
    """

    __tablename__ = "users"

    # String primary key with no default: callers must supply the id.
    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Optional, unique when present; presumably the id issued by ``provider``
    # (TODO confirm against the auth code).
    external_id: Mapped[str | None] = mapped_column(String(128), unique=True)
    # Presumably sized for the commonly cited 320-character max email length.
    email: Mapped[str | None] = mapped_column(String(320))
    name: Mapped[str] = mapped_column(String(255))
    # Falls back to "internal" when not supplied on insert.
    provider: Mapped[str] = mapped_column(String(32), default="internal")
    # Set client-side by SQLAlchemy at INSERT time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    # Set at INSERT and refreshed on every SQLAlchemy-issued UPDATE.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class WorkspaceRecord(Base):
    """Row in ``workspaces``.

    Members, API tokens, GitHub installations, projects, background jobs and
    activity events all reference ``workspaces.id`` with ON DELETE CASCADE.
    """

    __tablename__ = "workspaces"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Unique across all workspaces (unlike project slugs, which are only
    # unique per workspace).
    slug: Mapped[str] = mapped_column(String(128), unique=True, index=True)
    name: Mapped[str] = mapped_column(String(255))
    description: Mapped[str] = mapped_column(Text, default="")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    # Refreshed on every SQLAlchemy-issued UPDATE.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class WorkspaceMemberRecord(Base):
    """Row in ``workspace_members`` linking a user to a workspace with a role.

    ``uq_workspace_member`` allows at most one row per (workspace, user)
    pair; deleting either the workspace or the user cascades to this row.
    """

    __tablename__ = "workspace_members"
    __table_args__ = (
        UniqueConstraint("workspace_id", "user_id", name="uq_workspace_member"),
    )

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    workspace_id: Mapped[str] = mapped_column(
        ForeignKey("workspaces.id", ondelete="CASCADE"), index=True
    )
    user_id: Mapped[str] = mapped_column(
        ForeignKey("users.id", ondelete="CASCADE"), index=True
    )
    # Defaults to "owner"; the set of valid roles is not constrained here.
    role: Mapped[str] = mapped_column(String(32), default="owner")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
class ApiTokenRecord(Base):
    """Row in ``api_tokens``: a labelled token tied to one workspace and user.

    No plaintext token column exists on this table; only a prefix and a
    unique hash are stored. Deleting the workspace or user cascades here.
    """

    __tablename__ = "api_tokens"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    workspace_id: Mapped[str] = mapped_column(
        ForeignKey("workspaces.id", ondelete="CASCADE"), index=True
    )
    user_id: Mapped[str] = mapped_column(
        ForeignKey("users.id", ondelete="CASCADE"), index=True
    )
    label: Mapped[str] = mapped_column(String(255))
    # Indexed; presumably the non-secret leading characters of the token used
    # to narrow lookups before comparing hashes (TODO confirm).
    token_prefix: Mapped[str] = mapped_column(String(24), index=True)
    # Unique across all tokens; presumably a digest of the full token.
    token_hash: Mapped[str] = mapped_column(String(255), unique=True)
    # Callable default: each inserted row gets its own empty list.
    scopes: Mapped[list[str]] = mapped_column(JSON, default=list)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    # Nullable, no default; presumably NULL until first use / revocation.
    last_used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
    revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
class GitHubInstallationRecord(Base):
    """Row in ``github_installations``, owned by a workspace (CASCADE).

    ``github_repositories`` rows cascade-delete with their installation;
    ``projects.github_installation_id`` is set to NULL instead.
    """

    __tablename__ = "github_installations"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    workspace_id: Mapped[str] = mapped_column(
        ForeignKey("workspaces.id", ondelete="CASCADE"), index=True
    )
    # Nullable, unique when set; presumably the GitHub App installation id
    # (distinct from this table's own ``id``).
    installation_id: Mapped[str | None] = mapped_column(String(64), unique=True)
    account_login: Mapped[str | None] = mapped_column(String(255))
    account_type: Mapped[str | None] = mapped_column(String(64))
    # Defaults to "selected" when not supplied.
    repository_selection: Mapped[str] = mapped_column(String(32), default="selected")
    # Free-form JSON object; callable default gives each row a fresh dict.
    metadata_json: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class GitHubRepositoryRecord(Base):
    """Row in ``github_repositories``: a repository visible to an installation.

    ``uq_installation_repository_full_name`` allows each ``full_name`` at most
    once per installation. Rows cascade-delete with their installation.
    """

    __tablename__ = "github_repositories"
    __table_args__ = (
        UniqueConstraint(
            "github_installation_id",
            "full_name",
            name="uq_installation_repository_full_name",
        ),
    )

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    github_installation_id: Mapped[str] = mapped_column(
        ForeignKey("github_installations.id", ondelete="CASCADE"), index=True
    )
    # Nullable, indexed, not unique; presumably GitHub's numeric repo id as text.
    repository_id: Mapped[str | None] = mapped_column(String(64), index=True)
    # Presumably "owner/name"; also stored split into ``owner`` and ``name``.
    full_name: Mapped[str] = mapped_column(String(255), index=True)
    owner: Mapped[str] = mapped_column(String(255))
    name: Mapped[str] = mapped_column(String(255))
    default_branch: Mapped[str] = mapped_column(String(255), default="main")
    visibility: Mapped[str] = mapped_column(String(32), default="private")
    # Column type is inferred from the ``Mapped[bool]`` annotation.
    is_private: Mapped[bool] = mapped_column(default=True)
    is_archived: Mapped[bool] = mapped_column(default=False)
    metadata_json: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class ProjectRecord(Base):
    """Row in ``projects``, scoped to a workspace.

    Slugs are unique per workspace (``uq_project_workspace_slug``), not
    globally. Most project-scoped tables in this module reference
    ``projects.id`` with ON DELETE CASCADE.
    """

    __tablename__ = "projects"
    __table_args__ = (
        UniqueConstraint("workspace_id", "slug", name="uq_project_workspace_slug"),
    )

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    workspace_id: Mapped[str] = mapped_column(
        ForeignKey("workspaces.id", ondelete="CASCADE"), index=True
    )
    # Optional link; set to NULL (not deleted) if the installation is removed.
    github_installation_id: Mapped[str | None] = mapped_column(
        ForeignKey("github_installations.id", ondelete="SET NULL"), index=True
    )
    name: Mapped[str] = mapped_column(String(255))
    slug: Mapped[str] = mapped_column(String(128), index=True)
    description: Mapped[str] = mapped_column(Text, default="")
    public_url: Mapped[str] = mapped_column(String(1024))
    dashboard_url: Mapped[str] = mapped_column(String(1024))
    source_provider: Mapped[str] = mapped_column(String(32), default="github")
    # Defaults to "connected"; allowed values are not constrained here.
    onboarding_status: Mapped[str] = mapped_column(String(32), default="connected")
    # Nullable text blobs; presumably cached CLAUDE.md / AGENTS.md / llms.txt /
    # llms-full.txt content for the project (TODO confirm with the writers).
    claude_md_content: Mapped[str | None] = mapped_column(Text, nullable=True)
    agents_md_content: Mapped[str | None] = mapped_column(Text, nullable=True)
    llms_txt_content: Mapped[str | None] = mapped_column(Text, nullable=True)
    llms_full_txt_content: Mapped[str | None] = mapped_column(Text, nullable=True)
    last_scanned_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class RepositorySourceRecord(Base):
    """Row in ``repository_sources``: the source repository of a project.

    ``project_id`` is unique, so each project has at most one source row;
    it is deleted together with the project.
    """

    __tablename__ = "repository_sources"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), unique=True, index=True
    )
    provider: Mapped[str] = mapped_column(String(32), default="github")
    owner: Mapped[str] = mapped_column(String(255))
    repo: Mapped[str] = mapped_column(String(255))
    default_branch: Mapped[str] = mapped_column(String(255))
    # Presumably a filesystem path to a local checkout (TODO confirm).
    local_path: Mapped[str] = mapped_column(Text)
    current_ref: Mapped[str] = mapped_column(String(255))
    docs_root: Mapped[str] = mapped_column(Text)
    # Plain string, NOT a foreign key to ``github_installations``.
    installation_id: Mapped[str | None] = mapped_column(String(64))
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class ProjectSettingsRecord(Base):
    """Row in ``project_settings``: per-project configuration.

    ``project_id`` is unique, so there is at most one settings row per
    project. Unlike most tables here it has no ``created_at`` column.
    """

    __tablename__ = "project_settings"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), unique=True, index=True
    )
    docs_root: Mapped[str] = mapped_column(Text)
    config_version: Mapped[int] = mapped_column(Integer, default=1)
    # Free-form JSON objects; each callable default yields a fresh dict per row.
    config_json: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    ai_policy: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    publish_behavior: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    pr_behavior: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class ArtifactDefinitionRecord(Base):
    """Row in ``artifact_definitions``: a named artifact tracked for a project.

    ``artifact_key`` is unique within a project (``uq_project_artifact_key``).
    ``artifact_traces`` rows reference this table and cascade with it.
    """

    __tablename__ = "artifact_definitions"
    __table_args__ = (
        UniqueConstraint("project_id", "artifact_key", name="uq_project_artifact_key"),
    )

    # Wider (128) than most ids in this module; presumably derived from
    # project + artifact key by the caller (TODO confirm).
    id: Mapped[str] = mapped_column(String(128), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    artifact_key: Mapped[str] = mapped_column(String(128))
    slug: Mapped[str] = mapped_column(String(128))
    title: Mapped[str] = mapped_column(String(255))
    artifact_type: Mapped[str] = mapped_column(String(64))
    summary: Mapped[str] = mapped_column(Text)
    doc_paths: Mapped[list[str]] = mapped_column(JSON, default=list)
    source_patterns: Mapped[list[str]] = mapped_column(JSON, default=list)
    symbol_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)
    # Stored as Text, not the JSON type: callers serialize/deserialize it.
    symbols_json: Mapped[str | None] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class SourceSignalRecord(Base):
    """Row in ``source_signals``: a recorded source change for a project.

    ``verification_runs.signal_id`` references this table with CASCADE.
    """

    __tablename__ = "source_signals"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    signal_type: Mapped[str] = mapped_column(String(32))
    # Presumably a git ref / commit the signal was observed at (TODO confirm).
    ref: Mapped[str] = mapped_column(String(255))
    changed_files: Mapped[list[str]] = mapped_column(JSON, default=list)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
class VerificationRunRecord(Base):
    """Row in ``verification_runs``: one run triggered by a source signal.

    Deleted with its project or its signal. ``drift_findings`` keep their
    row but have ``verification_run_id`` set to NULL when a run is deleted.
    """

    __tablename__ = "verification_runs"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    signal_id: Mapped[str] = mapped_column(
        ForeignKey("source_signals.id", ondelete="CASCADE"), index=True
    )
    # Note: defaults to "completed", not a pending state.
    status: Mapped[str] = mapped_column(String(32), default="completed")
    findings_count: Mapped[int] = mapped_column(Integer, default=0)
    started_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
class ArtifactTraceRecord(Base):
    """Row in ``artifact_traces`` for one artifact definition in a project.

    ``uq_project_trace`` allows one trace per (project, artifact definition);
    the row cascades with either parent.
    """

    __tablename__ = "artifact_traces"
    __table_args__ = (
        UniqueConstraint(
            "project_id", "artifact_definition_id", name="uq_project_trace"
        ),
    )

    id: Mapped[str] = mapped_column(String(128), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    artifact_definition_id: Mapped[str] = mapped_column(
        ForeignKey("artifact_definitions.id", ondelete="CASCADE"), index=True
    )
    doc_paths: Mapped[list[str]] = mapped_column(JSON, default=list)
    source_paths: Mapped[list[str]] = mapped_column(JSON, default=list)
    # Nullable JSON objects; their schema is defined by the writers, not here.
    latest_source_revision: Mapped[dict[str, object] | None] = mapped_column(JSON)
    latest_doc_revision: Mapped[dict[str, object] | None] = mapped_column(JSON)
    # Required (NOT NULL) with no default.
    verification_status: Mapped[str] = mapped_column(String(32))
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class DriftFindingRecord(Base):
    """Row in ``drift_findings`` for a project.

    Survives deletion of its verification run (``verification_run_id`` is set
    to NULL) but is deleted with its project. ``doc_patches.finding_id`` is
    set to NULL when a finding is deleted.
    """

    __tablename__ = "drift_findings"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    verification_run_id: Mapped[str | None] = mapped_column(
        ForeignKey("verification_runs.id", ondelete="SET NULL"), index=True
    )
    # Plain string matching ``artifact_definitions.artifact_key``; not an FK.
    artifact_key: Mapped[str] = mapped_column(String(128), index=True)
    artifact_type: Mapped[str] = mapped_column(String(64))
    severity: Mapped[str] = mapped_column(String(32))
    summary: Mapped[str] = mapped_column(Text)
    rationale: Mapped[str] = mapped_column(Text)
    source_paths: Mapped[list[str]] = mapped_column(JSON, default=list)
    doc_paths: Mapped[list[str]] = mapped_column(JSON, default=list)
    suggested_actions: Mapped[list[str]] = mapped_column(JSON, default=list)
    source_revision: Mapped[dict[str, object] | None] = mapped_column(JSON)
    doc_revision: Mapped[dict[str, object] | None] = mapped_column(JSON)
    # Indexed; defaults to "open".
    status: Mapped[str] = mapped_column(String(32), default="open", index=True)
    # Explicitly nullable, so the annotation is Optional to match what reads
    # can return (rows inserted with an explicit None).
    changed_symbols: Mapped[list[dict[str, object]] | None] = mapped_column(
        JSON, default=list, nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class AgentRunRecord(Base):
    """Row in ``agent_runs``: a logged agent/model invocation for a project."""

    __tablename__ = "agent_runs"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    kind: Mapped[str] = mapped_column(String(64))
    provider: Mapped[str] = mapped_column(String(64))
    model: Mapped[str | None] = mapped_column(String(255))
    # Defaults to "completed", not a pending state.
    status: Mapped[str] = mapped_column(String(32), default="completed")
    input_summary: Mapped[str] = mapped_column(Text)
    output_summary: Mapped[str] = mapped_column(Text)
    metadata_json: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
class DocPatchRecord(Base):
    """Row in ``doc_patches``: a proposed change to one documentation file.

    Optionally linked to a drift finding (set to NULL if the finding is
    deleted). ``pull_requests.patch_id`` is unique, so each patch has at most
    one pull request, which is deleted with the patch.
    """

    __tablename__ = "doc_patches"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    finding_id: Mapped[str | None] = mapped_column(
        ForeignKey("drift_findings.id", ondelete="SET NULL"), index=True
    )
    artifact_key: Mapped[str] = mapped_column(String(128), index=True)
    target_path: Mapped[str] = mapped_column(Text)
    summary: Mapped[str] = mapped_column(Text)
    rationale: Mapped[str] = mapped_column(Text, default="")
    proposed_sections: Mapped[list[str]] = mapped_column(JSON, default=list)
    # List of JSON objects; their keys are defined by the writers, not here.
    citations: Mapped[list[dict[str, object]]] = mapped_column(JSON, default=list)
    preview_markdown: Mapped[str] = mapped_column(Text)
    # Defaults to "deterministic" when no AI provider is recorded.
    ai_provider: Mapped[str] = mapped_column(String(64), default="deterministic")
    model_name: Mapped[str | None] = mapped_column(String(255))
    chain_steps_used: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Nullable float; its range/scale is not constrained here.
    confidence_score: Mapped[float | None] = mapped_column(Float, nullable=True)
    status: Mapped[str] = mapped_column(String(32), default="draft")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class PullRequestRecord(Base):
    """Row in ``pull_requests``: the pull request opened for a doc patch.

    ``patch_id`` is unique (one PR per patch) and cascades with the patch.
    """

    __tablename__ = "pull_requests"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    patch_id: Mapped[str] = mapped_column(
        ForeignKey("doc_patches.id", ondelete="CASCADE"), unique=True, index=True
    )
    branch_name: Mapped[str] = mapped_column(String(255))
    title: Mapped[str] = mapped_column(String(255))
    url: Mapped[str] = mapped_column(String(1024))
    state: Mapped[str] = mapped_column(String(32), default="open")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class PublishDeploymentRecord(Base):
    """Row in ``publish_deployments``: one publish of a project's docs.

    ``published_pages`` rows reference a deployment and cascade with it.
    """

    __tablename__ = "publish_deployments"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    # Defaults to "completed", not a pending state.
    status: Mapped[str] = mapped_column(String(32), default="completed")
    commit_ref: Mapped[str] = mapped_column(String(255))
    site_url: Mapped[str] = mapped_column(String(1024))
    preview_url: Mapped[str] = mapped_column(String(1024))
    docs_count: Mapped[int] = mapped_column(Integer, default=0)
    # Set at INSERT time; there is no separate created_at column.
    generated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
class BackgroundJobRecord(Base):
    """Row in ``background_jobs``: a unit of deferred work.

    Both ``workspace_id`` and ``project_id`` are optional; when set, the job
    is deleted together with that workspace/project.
    """

    __tablename__ = "background_jobs"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    workspace_id: Mapped[str | None] = mapped_column(
        ForeignKey("workspaces.id", ondelete="CASCADE"), index=True
    )
    project_id: Mapped[str | None] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    job_kind: Mapped[str] = mapped_column(String(64), index=True)
    # Indexed; new jobs default to "pending".
    status: Mapped[str] = mapped_column(String(32), default="pending", index=True)
    attempt_count: Mapped[int] = mapped_column(Integer, default=0)
    error_summary: Mapped[str | None] = mapped_column(Text)
    # Loose (non-FK) pointer to whatever record the job acts on.
    resource_type: Mapped[str | None] = mapped_column(String(64))
    resource_id: Mapped[str | None] = mapped_column(String(128))
    result_summary: Mapped[str | None] = mapped_column(Text)
    payload_json: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    result_json: Mapped[dict[str, object] | None] = mapped_column(JSON)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class PublishedPageRecord(Base):
    """Row in ``published_pages``: one rendered page from a deployment.

    ``path`` is unique per project (``uq_project_published_path``). The
    composite index on (workspace_slug, project_slug, path) presumably serves
    lookups by public URL slugs without joining ``workspaces``/``projects``.
    """

    __tablename__ = "published_pages"
    __table_args__ = (
        UniqueConstraint("project_id", "path", name="uq_project_published_path"),
        Index(
            "ix_published_pages_workspace_project_path",
            "workspace_slug",
            "project_slug",
            "path",
        ),
    )

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    project_id: Mapped[str] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    deployment_id: Mapped[str] = mapped_column(
        ForeignKey("publish_deployments.id", ondelete="CASCADE"), index=True
    )
    # Denormalized copies of the slugs (plain strings, not FKs); they are not
    # updated here if the workspace/project slug changes later.
    workspace_slug: Mapped[str] = mapped_column(String(128), index=True)
    project_slug: Mapped[str] = mapped_column(String(128), index=True)
    path: Mapped[str] = mapped_column(String(255))
    title: Mapped[str] = mapped_column(String(255))
    description: Mapped[str] = mapped_column(Text, default="")
    content_markdown: Mapped[str] = mapped_column(Text)
    search_body: Mapped[str] = mapped_column(Text)
    source_path: Mapped[str] = mapped_column(Text)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow, onupdate=utcnow
    )
class ActivityEventRecord(Base):
    """Row in ``activity_events``: an entry in a workspace's activity feed.

    Always belongs to a workspace; optionally to a project. Deleted with
    either parent.
    """

    __tablename__ = "activity_events"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    workspace_id: Mapped[str] = mapped_column(
        ForeignKey("workspaces.id", ondelete="CASCADE"), index=True
    )
    project_id: Mapped[str | None] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True
    )
    kind: Mapped[str] = mapped_column(String(64), index=True)
    title: Mapped[str] = mapped_column(String(255))
    body: Mapped[str] = mapped_column(Text, default="")
    metadata_json: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
class GitHubWebhookDeliveryRecord(Base):
    """Row in ``github_webhook_deliveries``: a received GitHub webhook.

    Has no foreign keys; ``repository`` is a free-form string.
    """

    __tablename__ = "github_webhook_deliveries"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Unique when set, so inserting the same delivery id twice fails; presumably
    # GitHub's delivery GUID, used for de-duplication (TODO confirm).
    delivery_id: Mapped[str | None] = mapped_column(String(128), unique=True)
    event_name: Mapped[str] = mapped_column(String(64))
    repository: Mapped[str | None] = mapped_column(String(255))
    action: Mapped[str | None] = mapped_column(String(128))
    ref: Mapped[str | None] = mapped_column(String(255))
    status: Mapped[str] = mapped_column(String(32), default="received")
    # Stored as a JSON object; NOTE(review): likely the raw webhook body.
    payload: Mapped[dict[str, object]] = mapped_column(JSON, default=dict)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
    processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
class WaitlistSignupRecord(Base):
    """Row in ``waitlist_signups``: one email address on the waitlist.

    The only model here with an integer, database-generated primary key.
    """

    __tablename__ = "waitlist_signups"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Unique; comparison/case-folding semantics are left to the database.
    email: Mapped[str] = mapped_column(String(320), unique=True)
    source: Mapped[str] = mapped_column(String(64), default="landing")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=utcnow
    )
class EarlyAccessTokenRecord(Base):
    """Row in ``early_access_tokens``: a token issued to an email address.

    Unlike ``api_tokens``, the token value itself is stored (unique, indexed).
    """

    __tablename__ = "early_access_tokens"

    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    token: Mapped[str] = mapped_column(String(64), unique=True, index=True)
    # Indexed but not unique: one email may hold several tokens.
    email: Mapped[str] = mapped_column(String(320), index=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
    # NULL until activation (no default).
    activated_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    source: Mapped[str] = mapped_column(String(64), default="brain_teaser")
# Lazily initialised, module-wide database state. get_engine() builds it and
# reset_db() clears it.
_engine: Engine | None = None
# Session factory bound to ``_engine``; rebuilt whenever ``_engine`` is.
_session_factory: sessionmaker[Session] | None = None
# The database URL ``_engine`` was created for, used to detect config changes.
_engine_url: str | None = None
565def _sqlite_connect_args(url: str) -> dict[str, object]:
566 if url.startswith("sqlite"):
567 return {"check_same_thread": False}
568 return {}
def get_engine() -> Engine:
    """Return the cached SQLAlchemy engine for ``settings.database_url``.

    The engine and a matching ``sessionmaker`` (``expire_on_commit=False``)
    are stored in module globals and reused while the configured URL stays
    the same. When the URL changes, a new engine and session factory are
    built and the previous engine is disposed, so its pooled connections are
    released instead of lingering until garbage collection.

    NOTE(review): SQLite only enforces the ON DELETE CASCADE / SET NULL
    clauses declared on the models when ``PRAGMA foreign_keys=ON`` is issued
    per connection; nothing in this module does that — confirm whether
    SQLite deployments rely on those cascades.

    Returns:
        The engine bound to the current ``settings.database_url``.
    """
    global _engine, _session_factory, _engine_url
    database_url = settings.database_url
    if _engine is not None and _engine_url == database_url:
        return _engine

    previous = _engine
    # Build the replacement first: if create_engine raises, the existing
    # globals are left untouched.
    engine = create_engine(
        database_url,
        future=True,
        # Test connections on checkout so stale ones are transparently replaced.
        pool_pre_ping=True,
        connect_args=_sqlite_connect_args(database_url),
    )
    _engine = engine
    _session_factory = sessionmaker(engine, expire_on_commit=False)
    _engine_url = database_url
    if previous is not None:
        # Release the old pool. Connections still checked out are not closed
        # and the disposed engine remains usable if someone kept a reference.
        previous.dispose()
    return engine
def init_db() -> None:
    """Create every table declared on ``Base`` if schema auto-creation is on.

    A no-op when ``settings.auto_create_schema`` is falsy. Otherwise runs
    ``Base.metadata.create_all`` against the engine from ``get_engine()``,
    which only creates tables that do not already exist.
    """
    if not settings.auto_create_schema:
        return
    engine = get_engine()
    Base.metadata.create_all(engine)
def reset_db() -> None:
    """Dispose the cached engine (if any) and forget all cached DB state.

    After this call the next ``get_engine()`` builds a fresh engine and
    session factory. Safe to call repeatedly; also registered with
    ``atexit``.
    """
    global _engine, _session_factory, _engine_url
    current = _engine
    if current is not None:
        current.dispose()
    _engine, _session_factory, _engine_url = None, None, None
@contextmanager
def session_scope() -> Iterator[Session]:
    """Yield an ORM session wrapped in a single unit of work.

    The session is committed when the ``with`` body completes normally,
    rolled back (and the exception re-raised) when it raises an
    ``Exception`` or when the commit itself fails, and always closed.

    Yields:
        A new :class:`Session` from the module's session factory.

    Raises:
        RuntimeError: if ``get_engine()`` left no session factory behind.
            This should be unreachable; it replaces a bare ``assert`` that is
            stripped under ``python -O``.
    """
    # Always resolve through get_engine(): it is cheap when nothing changed,
    # and it rebuilds the factory if settings.database_url changed. Reusing an
    # existing factory unconditionally would keep sessions bound to the old
    # database whenever auto_create_schema is off.
    get_engine()
    if settings.auto_create_schema:
        # NOTE(review): this runs Base.metadata.create_all (table-existence
        # checks) at the start of every session; consider doing it once at
        # startup instead.
        init_db()
    factory = _session_factory
    if factory is None:
        raise RuntimeError("database session factory is not initialised")
    session = factory()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
# Import-time side effect: dispose the cached engine (and its connection pool)
# when the interpreter exits.
atexit.register(reset_db)