Coverage for fss\common\persistence\sqlmodel_impl.py: 51%
114 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-13 15:26 +0800
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-13 15:26 +0800
1"""SQLModel implementation that performs database operations"""
3from typing import Generic, TypeVar, List, Any, Type, Union
5from fastapi_pagination.ext.sqlmodel import paginate
6from pydantic import BaseModel
7from sqlmodel import SQLModel, select, func, insert, update, delete
8from sqlmodel.ext.asyncio.session import AsyncSession
10from fss.common.enum.enum import SortEnum
11from fss.common.persistence.base_mapper import BaseMapper
12from fss.middleware.db_session_middleware import db
14ModelType = TypeVar("ModelType", bound=SQLModel)
15CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
16UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
17SchemaType = TypeVar("SchemaType", bound=BaseModel)
18T = TypeVar("T", bound=SQLModel)
class SqlModelMapper(Generic[ModelType], BaseMapper):
    """Generic CRUD mapper bound to a single SQLModel table class.

    Every method accepts an optional ``db_session``; when it is omitted (or
    falsy) the method falls back to ``self.db.session``, where ``self.db`` is
    the ``db`` object imported from ``fss.middleware.db_session_middleware``.
    No method in this class calls ``commit`` or ``flush``.
    """

    def __init__(self, model: Type[ModelType]):
        """Bind the mapper to ``model`` and to the module-level ``db`` object."""
        self.model = model
        self.db = db

    def get_db_session(self) -> Type[Any]:
        """Return the ``db`` object stored on this mapper."""
        return self.db

    async def insert(
        self,
        *,
        data: Union[ModelType, SchemaType],
        db_session: Union[AsyncSession, None] = None,
    ) -> ModelType:
        """Validate ``data`` into the mapped model and add it to the session.

        Returns the validated model instance (not a row count).
        """
        db_session = db_session or self.db.session
        orm_data = self.model.model_validate(data)
        db_session.add(orm_data)
        return orm_data

    async def insert_batch(
        self, *, data_list: List[Any], db_session: Any = None
    ) -> int:
        """Insert all items of ``data_list`` with a single INSERT statement.

        Each item is validated through the mapped model before being dumped
        to a plain dict for the statement. Returns ``len(data_list)``.
        """
        # Nothing to insert: avoid building an INSERT with an empty VALUES list.
        if not data_list:
            return 0
        db_session = db_session or self.db.session
        rows = [self.model.model_validate(item).model_dump() for item in data_list]
        statement = insert(self.model).values(rows)
        await db_session.execute(statement)
        return len(data_list)

    async def select_by_id(self, *, id: Any, db_session: Any = None) -> Any:
        """Return the record whose ``id`` column equals ``id``, or ``None``."""
        db_session = db_session or self.db.session
        statement = select(self.model).where(self.model.id == id)
        response = await db_session.execute(statement)
        return response.scalar_one_or_none()

    async def select_by_ids(
        self, *, ids: List[Any], batch_size: int = 1000, db_session: Any = None
    ) -> List[Any]:
        """Return the records whose ``id`` is in ``ids``.

        The lookup is issued in chunks of ``batch_size`` ids so that a single
        ``IN (...)`` clause never carries more than ``batch_size`` values.
        """
        db_session = db_session or self.db.session
        result_set = []
        for start in range(0, len(ids), batch_size):
            batch_ids = ids[start : start + batch_size]
            statement = select(self.model).where(self.model.id.in_(batch_ids))
            # The coroutine must be awaited before ``.all()`` can be called on
            # its result; ``await session.exec(stmt).all()`` fails at runtime.
            batch_result = await db_session.exec(statement)
            result_set.extend(batch_result.all())
        return result_set

    async def select_count(self, *, db_session: Any = None) -> int:
        """Return the number of rows selected by ``select(self.model)``."""
        db_session = db_session or self.db.session
        response = await db_session.execute(
            select(func.count()).select_from(select(self.model).subquery())
        )
        return response.scalar_one()

    async def select_list(
        self, *, page: int = 1, size: int = 100, query: Any, db_session: Any = None
    ) -> List[Any]:
        """Execute ``query`` and return all scalar results.

        When ``query`` is ``None`` a default query is built that selects the
        mapped model ordered by ``id`` with ``page``/``size`` applied as
        offset/limit. ``page`` and ``size`` are not applied to a caller
        supplied ``query``.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = (
                select(self.model)
                .offset((page - 1) * size)
                .limit(size)
                .order_by(self.model.id)
            )
        response = await db_session.execute(query)
        return response.scalars().all()

    async def select_list_ordered(
        self,
        *,
        page: int = 1,
        size: int = 100,
        query: Any,
        order_by: Any,
        sort_order: Any,
        db_session: Any = None,
    ) -> List[Any]:
        """Return one page of records sorted by the ``order_by`` column.

        ``order_by`` falls back to ``"id"`` when it is ``None`` or not a
        column of the mapped table. Sorting is ascending only when
        ``sort_order == SortEnum.ascending``; any other value sorts
        descending.

        NOTE(review): the ``query`` argument is accepted but never used; the
        statement is always rebuilt from ``select(self.model)``. Kept as is
        to preserve current results for existing callers.
        """
        db_session = db_session or self.db.session
        columns = self.model.__table__.columns
        if order_by is None or order_by not in columns:
            order_by = "id"
        sort_column = columns[order_by]
        if sort_order == SortEnum.ascending:
            ordering = sort_column.asc()
        else:
            ordering = sort_column.desc()
        query = (
            select(self.model)
            .offset((page - 1) * size)
            .limit(size)
            .order_by(ordering)
        )
        response = await db_session.execute(query)
        return response.scalars().all()

    async def select_list_page(
        self, *, params: Any, query: Any, db_session: Any = None
    ) -> List[Any]:
        """Paginate ``query`` (default ``select(self.model)``) with ``params``.

        Delegates to ``fastapi_pagination.ext.sqlmodel.paginate`` and returns
        its result unchanged.
        """
        db_session = db_session or self.db.session
        if query is None:
            query = select(self.model)
        return await paginate(db_session, query, params)

    async def select_list_page_ordered(
        self,
        *,
        params: Any,
        query: Any = None,
        order_by: Any = None,
        sort_order: Any = None,
        db_session: Any = None,
    ) -> List[Any]:
        """Paginate records with ``params``, sorted by the ``order_by`` column.

        A caller supplied ``query`` is paginated as given; ordering is only
        applied when ``query`` is ``None``. ``order_by`` falls back to
        ``"id"`` when it is ``None`` or not a column of the mapped table, and
        sorting is ascending only when ``sort_order == SortEnum.ascending``.
        """
        db_session = db_session or self.db.session
        columns = self.model.__table__.columns
        if order_by is None or order_by not in columns:
            order_by = "id"
        if query is None:
            sort_column = columns[order_by]
            if sort_order == SortEnum.ascending:
                query = select(self.model).order_by(sort_column.asc())
            else:
                query = select(self.model).order_by(sort_column.desc())
        return await paginate(db_session, query, params)

    async def update_by_id(
        self, *, data: Any, db_session: Any = None
    ) -> Union[ModelType, int]:
        """Load the record identified by ``data``'s id and apply its fields.

        ``data`` may be a dict (must contain an ``"id"`` key) or a model
        exposing ``id`` and ``model_dump``. For a model only explicitly set
        fields are applied (``exclude_unset=True``).

        Returns the updated ORM instance, or ``0`` when no record has that id.
        """
        db_session = db_session or self.db.session
        # Resolve the id per input type; a dict has no ``.id`` attribute.
        if isinstance(data, dict):
            record_id = data["id"]
            update_data = data
        else:
            record_id = data.id
            update_data = data.model_dump(exclude_unset=True)
        query = select(self.model).where(self.model.id == record_id)
        result = await db_session.execute(query)
        # ``execute`` never returns None; the missing-row case has to be
        # detected on the fetched scalar instead.
        db_data = result.scalar_one_or_none()
        if db_data is None:
            return 0
        for field, value in update_data.items():
            setattr(db_data, field, value)
        db_session.add(db_data)
        return db_data

    async def update_batch_by_ids(
        self, *, data_list: List[Any], db_session: Any = None
    ) -> int:
        """Issue one UPDATE per item of ``data_list`` that has an ``id``.

        Only explicitly set fields of each item are written. Items without an
        ``id`` attribute are skipped, but the return value is still
        ``len(data_list)``.
        """
        db_session = db_session or self.db.session
        for data in data_list:
            if not hasattr(data, "id"):
                continue
            statement = (
                update(self.model)
                .where(self.model.id == data.id)
                # ``model_dump`` matches the Pydantic v2 API used elsewhere in
                # this class (``model_validate``); ``.dict()`` is deprecated.
                .values(**data.model_dump(exclude_unset=True))
            )
            await db_session.execute(statement)
        return len(data_list)

    async def delete_by_id(self, *, id: Any, db_session: Any = None) -> int:
        """Load the record with the given ``id`` and delete it via the session.

        ``scalar_one()`` is used for the lookup, so a missing record raises
        rather than returning a value.

        NOTE(review): the return value is whatever ``db_session.delete``
        returns, which is presumably ``None`` rather than an ``int`` -
        confirm against the session implementation in use.
        """
        db_session = db_session or self.db.session
        statement = select(self.model).where(self.model.id == id)
        response = await db_session.execute(statement)
        data = response.scalar_one()
        return await db_session.delete(data)

    async def delete_batch_by_ids(
        self, *, ids: List[Any], db_session: Any = None
    ) -> int:
        """Delete all records whose ``id`` is in ``ids`` with one statement.

        Returns the ``rowcount`` reported for the DELETE.
        """
        db_session = db_session or self.db.session
        statement = delete(self.model).where(self.model.id.in_(ids))
        result = await db_session.execute(statement)
        return result.rowcount