Coverage for fss\middleware\db_session_middleware.py: 82%
68 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-12 22:20 +0800
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-12 22:20 +0800
1"""Session proxy used in the project"""
3from contextvars import ContextVar
4from typing import Dict, Optional, Union
6from sqlalchemy.engine import Engine
7from sqlalchemy.engine.url import URL
8from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
9from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
10from starlette.requests import Request
11from starlette.types import ASGIApp
14try:
15 from sqlalchemy.ext.asyncio import async_sessionmaker
16except ImportError:
17 from sqlalchemy.orm import sessionmaker as async_sessionmaker
def create_middleware_and_session_proxy():
    """Build a middleware class and a session-proxy class that share state.

    Both classes close over the same session factory (``_Session``) and the
    same context variable (``_session``): constructing the middleware
    configures the factory, and the proxy hands out the session stored in the
    context variable for the current context.

    Returns:
        A ``(SQLAlchemyMiddleware, DBSession)`` tuple.
    """
    _Session: Optional[async_sessionmaker] = None
    # Usage of context vars inside closures is not recommended, since they are not properly
    # garbage collected, but in our use case the context var is created on program startup
    # and is used throughout its whole lifecycle.
    _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)

    class SQLAlchemyMiddleware(BaseHTTPMiddleware):
        """Starlette middleware that wraps every request in a ``DBSession`` context."""

        def __init__(
            self,
            app: ASGIApp,
            db_url: Optional[Union[str, URL]] = None,
            custom_engine: Optional[Engine] = None,
            engine_args: Optional[Dict] = None,
            session_args: Optional[Dict] = None,
            commit_on_exit: bool = True,
        ):
            """Configure the shared session factory.

            Args:
                app: The ASGI application being wrapped.
                db_url: Database URL passed to ``create_async_engine``; only
                    used when ``custom_engine`` is not given.
                custom_engine: Ready-made engine; takes precedence over ``db_url``.
                engine_args: Extra keyword arguments for ``create_async_engine``.
                session_args: Extra keyword arguments for the session factory.
                commit_on_exit: Whether the per-request session is committed
                    when the request finishes without an exception.

            Raises:
                ValueError: If neither ``db_url`` nor ``custom_engine`` is given.
            """
            super().__init__(app)
            self.commit_on_exit = commit_on_exit
            engine_args = engine_args or {}
            session_args = session_args or {}

            if not custom_engine and not db_url:
                raise ValueError(
                    "You need to pass a db_url or a custom_engine parameter."
                )
            if not custom_engine:
                engine = create_async_engine(db_url, **engine_args)
            else:
                engine = custom_engine

            # Publish the factory to the enclosing scope so DBSession can use it.
            nonlocal _Session
            _Session = async_sessionmaker(
                engine, class_=AsyncSession, expire_on_commit=False, **session_args
            )

        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
            """Run the downstream handler inside a fresh session context."""
            async with DBSession(commit_on_exit=self.commit_on_exit):
                return await call_next(request)

    class DBSessionMeta(type):
        """Metaclass exposing ``session`` as a property on the ``DBSession`` class itself."""

        @property
        def session(self) -> AsyncSession:
            """Return an instance of Session local to the current async context.

            Raises:
                SessionNotInitialisedException: If the session factory has not
                    been configured yet.
                MissingSessionException: If no session is set in the current
                    context.
            """
            if _Session is None:
                raise SessionNotInitialisedException

            session = _session.get()
            if session is None:
                raise MissingSessionException

            return session

    class DBSession(metaclass=DBSessionMeta):
        """Async context manager that opens a session and stores it in the context var."""

        def __init__(self, session_args: Optional[Dict] = None, commit_on_exit: bool = False):
            """Remember the per-session arguments and the commit policy.

            Args:
                session_args: Keyword arguments passed to the session factory
                    when the context is entered.
                commit_on_exit: Whether to commit when the block exits without
                    an exception.
            """
            self.token = None
            self.session_args = session_args or {}
            self.commit_on_exit = commit_on_exit

        async def __aenter__(self):
            """Create a session, store it in the context var and return the class.

            Raises:
                SessionNotInitialisedException: If the session factory has not
                    been configured yet.
            """
            if not isinstance(_Session, async_sessionmaker):
                raise SessionNotInitialisedException

            self.token = _session.set(_Session(**self.session_args))  # type: ignore
            # The class (not the instance) carries the ``session`` property.
            return type(self)

        async def __aexit__(self, exc_type, exc_value, traceback):
            """Roll back or commit, then close the session and restore the context var."""
            session = _session.get()

            try:
                if exc_type is not None:
                    await session.rollback()
                elif self.commit_on_exit:
                    # ``elif`` so that a rollback is never followed by a commit.
                    await session.commit()
            finally:
                try:
                    await session.close()
                finally:
                    # Always restore the previous value, even when close()
                    # raises, so a dead session never stays in the context var.
                    _session.reset(self.token)

    return SQLAlchemyMiddleware, DBSession
# Build the middleware class and its companion session proxy once, at import
# time; ``db`` is the session-proxy class returned by the factory.
(SQLAlchemyMiddleware, db) = create_middleware_and_session_proxy()
class MissingSessionException(Exception):
    """
    Exception raised when the user tries to access a database session before it is created.

    The message explains how to open a session context manually.
    """

    def __init__(self):
        # The example awaits ``execute`` before calling ``fetchall`` on its
        # result; calling ``fetchall`` on the un-awaited coroutine would fail.
        detail = """
        No session found! Either you are not currently in a request context,
        or you need to manually create a session context by using a `db` instance as
        a context manager e.g.:

        async with db():
            (await db.session.execute(foo.select())).fetchall()
        """

        super().__init__(detail)
class SessionNotInitialisedException(Exception):
    """
    Exception raised when a DB session is requested before the session factory is initialised.
    """

    def __init__(self):
        # Name the middleware class this module actually defines, so the user
        # can find what has to be initialised.
        detail = """
        Session not initialised! Ensure that SQLAlchemyMiddleware has been initialised before
        attempting database access.
        """

        super().__init__(detail)