Coverage for src/pydal2sql_core/core.py: 100%

95 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-08-05 17:25 +0200

1""" 

2Main functionality. 

3""" 

4 

5import functools 

6import pickle # nosec: B403 

7import typing 

8from pathlib import Path 

9from typing import Any 

10 

11from pydal.adapters import MySQL, Postgre, SQLite 

12from pydal.dialects import ( 

13 Dialect, 

14 MySQLDialect, 

15 PostgreDialect, 

16 SQLDialect, 

17 SQLiteDialect, 

18) 

19from pydal.migrator import Migrator 

20from pydal.objects import Table 

21 

22from .helpers import TempdirOrExistingDir, get_typing_args 

23from .types import ( 

24 SUPPORTED_DATABASE_TYPES, 

25 SUPPORTED_DATABASE_TYPES_WITH_ALIASES, 

26 CustomAdapter, 

27 DummyDAL, 

28 SQLAdapter, 

29) 

30 

31 

def sql_not_null(self: SQLDialect, default: Any, field_type: Any) -> str:
    """
    Build the NOT NULL clause of a column definition for the given dialect.

    pydal's stock clause is always "NOT NULL DEFAULT <value>". That is wrong for dynamic defaults
    (callables such as uuid.uuid, datetime.now or a lambda): their value must not be frozen into the
    migration statement. So for a callable ``default`` only "NOT NULL" is emitted; any other value is
    rendered through the dialect's adapter and appended as the column DEFAULT.

    Args:
        self (SQLDialect): The dialect whose adapter renders the default value.
        default (Any): The field's default value (static value or callable).
        field_type (Any): The field's type, passed on to ``self.adapter.represent``.

    Returns:
        str: "NOT NULL" for callable defaults, otherwise "NOT NULL DEFAULT <rendered value>".
    """
    # Dynamic default: never hard-code a single evaluated value into the DDL.
    if callable(default):
        return "NOT NULL"

    rendered_default = self.adapter.represent(default, field_type)
    return "NOT NULL DEFAULT %s" % rendered_default

57 

58 

def _modify_migrator(self: Migrator) -> Migrator:
    """
    Modify the SQL NOT NULL constraint logic of a Migrator object.

    If the Migrator's SQL dialect still uses the base ``SQLDialect.not_null`` logic, it is replaced by
    ``sql_not_null`` (bound to that dialect), which omits a hard-coded DEFAULT for callable defaults.
    A dialect that defines its own ``not_null`` logic is left unchanged.

    Calling this more than once on the same Migrator is safe. After the first call ``not_null`` is a
    ``functools.partial``, which has no ``__func__``, so it no longer matches the base implementation and
    is left as-is (previously this raised AttributeError).

    Args:
        self (Migrator): The Migrator object to modify; its ``adapter.dialect`` is patched in place.

    Returns:
        Migrator: The same (modified) Migrator object.
    """
    dialect = self.adapter.dialect
    # __func__ gives the function underneath a bound method; anything that is not a bound method
    # (e.g. the partial installed by an earlier call) yields None instead of raising AttributeError.
    current_impl = getattr(dialect.not_null, "__func__", None)

    if current_impl is SQLDialect.not_null:
        # only modify base SQL notnull; if the dialect has modified logic, that should just be used.
        # bind 'self' parameter already and monkey patch the default logic on this dialect instance:
        dialect.not_null = functools.partial(sql_not_null, dialect)

    return self

82 

83 

def _build_dummy_migrator(_driver_name: SUPPORTED_DATABASE_TYPES_WITH_ALIASES, /, db_folder: str) -> Migrator:
    """
    Create a Migrator specific to the SQL dialect of ``_driver_name``.

    ``_driver_name`` may be a canonical driver name (psycopg2, sqlite3, pymysql) or one of the aliases
    below (e.g. postgres, sqlite, mysql); it is lower-cased before lookup. The Migrator is built on a
    ``DummyDAL`` with no URI whose folder is ``db_folder``. That DAL uses an ad-hoc adapter which copies
    ``types`` and ``dbengine`` from the matching pydal adapter class, and its dialect's NOT NULL handling
    is adjusted by ``_modify_migrator``.

    Raises:
        ValueError: if the alias-resolved name is not a supported database type, or if no installed
            driver for it is found in ``db._drivers_available``.
    """
    db = DummyDAL(None, migrate=False, folder=db_folder)

    # alias -> canonical driver name; unknown names pass through unchanged (and are validated below)
    alias_map = {
        "postgresql": "psycopg2",
        "postgres": "psycopg2",
        "psql": "psycopg2",
        "sqlite": "sqlite3",
        "sqlite:memory": "sqlite3",
        "mysql": "pymysql",
    }
    lowered = _driver_name.lower()
    canonical = alias_map.get(lowered, lowered)

    if canonical not in get_typing_args(SUPPORTED_DATABASE_TYPES):
        raise ValueError(
            f"Unsupported database type {canonical}. "
            f"Choose one of {get_typing_args(SUPPORTED_DATABASE_TYPES_WITH_ALIASES)}"
        )

    # canonical driver -> (pydal adapter class to copy types/dbengine from, dialect class to render SQL)
    backend_by_driver: dict[str, tuple[typing.Type[SQLAdapter], typing.Type[Dialect]]] = {
        "psycopg2": (Postgre, PostgreDialect),
        "sqlite3": (SQLite, SQLiteDialect),
        "pymysql": (MySQL, MySQLDialect),
    }
    adapter_cls, dialect_cls = backend_by_driver[canonical]

    installed_driver = db._drivers_available.get(canonical)
    if not installed_driver:  # pragma: no cover
        raise ValueError(f"Please install the correct driver for database type {canonical}")

    class DummyAdapter(CustomAdapter):
        types = adapter_cls.types
        driver = installed_driver
        dbengine = adapter_cls.dbengine

        commit_on_alter_table = True

    adapter = DummyAdapter(db, "", adapter_args={"driver": installed_driver})
    adapter.dialect = dialect_cls(adapter)
    db._adapter = adapter

    return _modify_migrator(Migrator(adapter))

144 

145 

def generate_create_statement(
    define_table: Table, db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None, *, db_folder: str = None
) -> str:
    """
    Render the ``CREATE TABLE`` SQL for ``define_table`` in the requested dialect.

    ``define_table`` is a pydal Table (the result of ``db.define_table('mytable')`` or simply ``db.mytable``).
    ``db_type`` (e.g. postgres, sqlite, mysql) selects the dialect. When it is omitted, the ``_dbname`` of
    the table's database is used. The target dialect may differ from the database actually in use, so a
    dummy database is enough to generate SQL code::

        db = pydal.DAL(None, migrate=False)

    ``db_folder`` is the database folder where migration (``.table``) files are stored. By default, a
    random temporary dir is created.

    Raises:
        ValueError: if no ``db_type`` is given and none can be guessed from the table's database.
    """
    # an explicit (truthy) db_type wins; otherwise fall back to the table's own database name
    dialect_name = db_type or getattr(define_table._db, "_dbname", None)

    if dialect_name is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    with TempdirOrExistingDir(db_folder) as folder:
        migrator = _build_dummy_migrator(dialect_name, db_folder=folder)
        create_sql: str = migrator.create_table(define_table, migrate=False, fake_migrate=True)

    return create_sql

177 

178 

def sql_fields_through_tablefile(
    define_table: Table,
    db_folder: typing.Optional[str | Path] = None,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
) -> dict[str, Any]:
    """
    Collect the SQL field metadata of ``define_table`` by simulating a migration via a table file.

    A dummy migrator for ``db_type`` runs ``create_table`` with ``migrate=True, fake_migrate=True``. The
    pickled table file at ``define_table._dbt`` (resolved relative to the working folder) is then loaded
    and returned.

    Args:
        define_table (Table): The pydal Table to inspect.
        db_folder (str or Path, optional): Folder to work in. If not specified, a temporary directory is
            used for the operation. Defaults to None.
        db_type (str or SUPPORTED_DATABASE_TYPES_WITH_ALIASES, optional): Target database type (e.g.
            "postgres", "mysql"). When falsy, it is taken from ``define_table._db._dbname``. Defaults to None.

    Returns:
        dict[str, Any]: The unpickled table-file contents, keyed by field name.

    Raises:
        ValueError: If no ``db_type`` is given and it cannot be guessed from ``define_table``.
    """
    resolved_type = db_type if db_type else getattr(define_table._db, "_dbname", None)

    if resolved_type is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    with TempdirOrExistingDir(db_folder) as folder:
        migrator = _build_dummy_migrator(resolved_type, db_folder=folder)
        migrator.create_table(define_table, migrate=True, fake_migrate=True)

        # read back the table file pydal produced for this table (must happen before a temp dir is removed)
        table_file = Path(folder) / define_table._dbt
        with table_file.open("rb") as handle:
            field_info = pickle.load(handle)  # nosec B301

    return typing.cast(dict[str, Any], field_info)

221 

222 

def generate_alter_statement(
    define_table_old: Table,
    define_table_new: Table,
    /,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
    *,
    db_folder: typing.Optional[str] = None,
) -> str:
    """
    Generate SQL ALTER statements to update the `define_table_old` to `define_table_new`.

    Both table versions are turned into table-file field metadata. A fake migration from old to new is
    then run with a dummy migrator, and the ``ALTER``/``UPDATE`` lines it writes to ``sql.log`` in the
    working folder are returned.

    Args:
        define_table_old (Table): The `Table` object representing the old version of the table.
        define_table_new (Table): The `Table` object representing the new version of the table.
        db_type (str or SUPPORTED_DATABASE_TYPES_WITH_ALIASES, optional): The type of the database (e.g., "postgres",
            "mysql", etc.). If not provided, the database type will be guessed based on the `_db` attribute of the
            `define_table_old` and `define_table_new` objects.
            If the guess fails, a ValueError is raised. Defaults to None.
        db_folder (str, optional): The path to the database folder or directory to use. If not specified,
            a temporary directory is used for the operation. Defaults to None.

    Returns:
        str: The matching ``ALTER``/``UPDATE`` lines (newlines included), or "" when there are no changes.

    Raises:
        ValueError: If the `db_type` is not provided, and it cannot be guessed from the `define_table_old` and
            `define_table_new` objects.
    """
    if not db_type:
        db_type = getattr(define_table_old._db, "_dbname", None) or getattr(define_table_new._db, "_dbname", None)

    if db_type is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    # other db_folder than new! (the old table file must not be overwritten by the new one)
    old_fields = sql_fields_through_tablefile(define_table_old, db_type=db_type, db_folder=None)

    with TempdirOrExistingDir(db_folder) as db_folder:
        db_folder_path = Path(db_folder)
        new_fields = sql_fields_through_tablefile(define_table_new, db_type=db_type, db_folder=db_folder)

        migrator = _build_dummy_migrator(db_type, db_folder=db_folder)

        # start from an empty log so only this migration's statements are collected below
        sql_log = db_folder_path / "sql.log"
        sql_log.unlink(missing_ok=True)

        # run the migration with both tables temporarily attached to the dummy migrator's DAL;
        # the original databases are always restored afterwards.
        original_db_old = define_table_old._db
        original_db_new = define_table_new._db
        try:
            define_table_old._db = migrator.db
            define_table_new._db = migrator.db

            migrator.migrate_table(
                define_table_new,
                new_fields,
                old_fields,
                new_fields,
                str(db_folder_path / "<deprecated>"),
                fake_migrate=True,
            )

            if not sql_log.exists():
                # no changes!
                return ""

            # Decode explicitly: the default text encoding is locale-dependent (e.g. cp1252 on Windows) and
            # would mangle non-ASCII SQL such as string defaults.
            # NOTE(review): pydal's Migrator.log appears to write UTF-8 bytes; confirm when upgrading pydal.
            with sql_log.open(encoding="utf-8") as f:
                statements = [line for line in f if line.startswith(("ALTER", "UPDATE"))]
        finally:
            define_table_new._db = original_db_new
            define_table_old._db = original_db_old

    return "".join(statements)

301 

302 

def generate_sql(
    define_table: Table,
    define_table_new: typing.Optional[Table] = None,
    /,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
    *,
    db_folder: str = None,
) -> str:
    """
    Produce CREATE or ALTER SQL for a pydal Table, depending on whether a newer version is supplied.

    With only ``define_table`` (or a falsy ``define_table_new``), this returns the CREATE statement
    built by ``generate_create_statement``. With ``define_table_new``, it returns the ALTER statements
    built by ``generate_alter_statement`` that turn ``define_table`` into ``define_table_new``.

    Args:
        define_table (Table): The table to create, or the old version when altering.
        define_table_new (Table, optional): The new version of the table; triggers ALTER generation.
            Defaults to None.
        db_type (str or SUPPORTED_DATABASE_TYPES_WITH_ALIASES, optional): Target database type (e.g.
            "postgres", "mysql"). When omitted, it is guessed from the tables' ``_db``. Defaults to None.
        db_folder (str, optional): Folder for migration files; a temporary directory when omitted.
            Defaults to None.

    Returns:
        str: The generated SQL.

    Raises:
        ValueError: If no ``db_type`` is given and it cannot be guessed (raised by the delegated function).
    """
    if not define_table_new:
        return generate_create_statement(define_table, db_type=db_type, db_folder=db_folder)

    return generate_alter_statement(define_table, define_table_new, db_type=db_type, db_folder=db_folder)