Coverage for fss\common\persistence\sqlmodel_impl.py: 51%

114 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-13 15:26 +0800

1"""Sqlmodel impl that do database operations""" 

2 

3from typing import Generic, TypeVar, List, Any, Type, Union 

4 

5from fastapi_pagination.ext.sqlmodel import paginate 

6from pydantic import BaseModel 

7from sqlmodel import SQLModel, select, func, insert, update, delete 

8from sqlmodel.ext.asyncio.session import AsyncSession 

9 

10from fss.common.enum.enum import SortEnum 

11from fss.common.persistence.base_mapper import BaseMapper 

12from fss.middleware.db_session_middleware import db 

13 

# --- Type parameters for the persistence layer -------------------------------
# Table models: the SQLModel class a mapper instance operates on.
ModelType = TypeVar("ModelType", bound=SQLModel)
# Generic SQLModel-bound parameter (not referenced in the visible part of this
# module; presumably imported elsewhere — verify before removing).
T = TypeVar("T", bound=SQLModel)

# Pydantic schemas accepted as inputs (create / update payloads, generic DTOs).
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
SchemaType = TypeVar("SchemaType", bound=BaseModel)

19 

20 

21class SqlModelMapper(Generic[ModelType], BaseMapper): 

22 def __init__(self, model: Type[ModelType]): 

23 self.model = model 

24 self.db = db 

25 

26 def get_db_session(self) -> Type[Any]: 

27 return self.db 

28 

29 async def insert( 

30 self, 

31 *, 

32 data: Union[ModelType, SchemaType], 

33 db_session: Union[AsyncSession, None] = None, 

34 ) -> int: 

35 db_session = db_session or self.db.session 

36 orm_data = self.model.model_validate(data) 

37 db_session.add(orm_data) 

38 return orm_data 

39 

40 async def insert_batch( 

41 self, *, data_list: List[Any], db_session: Any = None 

42 ) -> int: 

43 db_session = db_session or self.db.session 

44 orm_datas = [] 

45 for data in data_list: 

46 orm_datas.append(self.model.model_validate(data)) 

47 statement = insert(self.model).values([data.model_dump() for data in orm_datas]) 

48 await db_session.execute(statement) 

49 return len(data_list) 

50 

51 async def select_by_id(self, *, id: Any, db_session: Any = None) -> Any: 

52 db_session = db_session or self.db.session 

53 statement = select(self.model).where(self.model.id == id) 

54 response = await db_session.execute(statement) 

55 return response.scalar_one_or_none() 

56 

57 async def select_by_ids( 

58 self, *, ids: List[Any], batch_size: int = 1000, db_session: Any = None 

59 ) -> List[Any]: 

60 db_session = db_session or self.db.session 

61 result_set = [] 

62 for i in range(0, len(ids), batch_size): 

63 batch_ids = ids[i : i + batch_size] 

64 statement = select(self.model).where(self.model.id.in_(batch_ids)) 

65 results = await db_session.exec(statement).all() 

66 result_set.extend(results) 

67 return result_set 

68 

69 async def select_count(self, *, db_session: Any = None) -> int: 

70 db_session = db_session or self.db.session 

71 response = await db_session.execute( 

72 select(func.count()).select_from(select(self.model).subquery()) 

73 ) 

74 return response.scalar_one() 

75 

76 async def select_list( 

77 self, *, page: int = 1, size: int = 100, query: Any, db_session: Any = None 

78 ) -> List[Any]: 

79 db_session = db_session or self.db.session 

80 if query is None: 

81 query = ( 

82 select(self.model) 

83 .offset((page - 1) * size) 

84 .limit(size) 

85 .order_by(self.model.id) 

86 ) 

87 response = await db_session.execute(query) 

88 return response.scalars().all() 

89 

90 async def select_list_ordered( 

91 self, 

92 *, 

93 page: int = 1, 

94 size: int = 100, 

95 query: Any, 

96 order_by: Any, 

97 sort_order: Any, 

98 db_session: Any = None, 

99 ) -> List[Any]: 

100 db_session = db_session or self.db.session 

101 columns = self.model.__table__.columns 

102 if order_by is None or order_by not in columns: 

103 order_by = "id" 

104 if sort_order == SortEnum.ascending: 

105 query = ( 

106 select(self.model) 

107 .offset((page - 1) * size) 

108 .limit(size) 

109 .order_by(columns[order_by].asc()) 

110 ) 

111 else: 

112 query = ( 

113 select(self.model) 

114 .offset((page - 1) * size) 

115 .limit(size) 

116 .order_by(columns[order_by].desc()) 

117 ) 

118 response = await db_session.execute(query) 

119 return response.scalars().all() 

120 

121 async def select_list_page( 

122 self, *, params: Any, query: Any, db_session: Any = None 

123 ) -> List[Any]: 

124 db_session = db_session or self.db.session 

125 if query is None: 

126 query = select(self.model) 

127 response = await paginate(db_session, query, params) 

128 return response 

129 

130 async def select_list_page_ordered( 

131 self, 

132 *, 

133 params: Any, 

134 query: Any = None, 

135 order_by: Any = None, 

136 sort_order: Any = None, 

137 db_session: Any = None, 

138 ) -> List[Any]: 

139 db_session = db_session or self.db.session 

140 columns = self.model.__table__.columns 

141 if order_by is None or order_by not in columns: 

142 order_by = "id" 

143 if query is None: 

144 if sort_order == SortEnum.ascending: 

145 query = select(self.model).order_by(columns[order_by].asc()) 

146 else: 

147 query = select(self.model).order_by(columns[order_by].desc()) 

148 return await paginate(db_session, query, params) 

149 

150 async def update_by_id(self, *, data: Any, db_session: Any = None) -> int: 

151 db_session = db_session or self.db.session 

152 query = select(self.model).where(self.model.id == data.id) 

153 result = await db_session.execute(query) 

154 if result is None: 

155 return 0 

156 db_data = result.scalar_one() 

157 if isinstance(data, dict): 

158 update_data = data 

159 else: 

160 update_data = data.model_dump(exclude_unset=True) 

161 for field in update_data: 

162 setattr(db_data, field, update_data[field]) 

163 

164 db_session.add(db_data) 

165 return db_data 

166 

167 async def update_batch_by_ids( 

168 self, *, data_list: List[Any], db_session: Any = None 

169 ) -> int: 

170 db_session = db_session or self.db.session 

171 for data in data_list: 

172 if hasattr(data, "id"): 

173 statement = ( 

174 update(self.model) 

175 .where(self.model.id == data.id) 

176 .values(**data.dict(exclude_unset=True)) 

177 ) 

178 await db_session.execute(statement) 

179 return len(data_list) 

180 

181 async def delete_by_id(self, *, id: Any, db_session: Any = None) -> int: 

182 db_session = db_session or self.db.session 

183 statement = select(self.model).where(self.model.id == id) 

184 response = await db_session.execute(statement) 

185 data = response.scalar_one() 

186 return await db_session.delete(data) 

187 

188 async def delete_batch_by_ids( 

189 self, *, ids: List[Any], db_session: Any = None 

190 ) -> int: 

191 db_session = db_session or self.db.session 

192 statement = delete(self.model).where(self.model.id.in_(ids)) 

193 result = await db_session.execute(statement) 

194 return result.rowcount