Coverage for src/pydal2sql_core/core.py: 100%

98 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-04-22 11:38 +0200

1""" 

2Main functionality. 

3""" 

4 

5import functools 

6import pickle # nosec: B403 

7import typing 

8from pathlib import Path 

9from typing import Any 

10 

11from pydal.adapters import MySQL, Postgre, SQLite 

12from pydal.dialects import ( 

13 Dialect, 

14 MySQLDialect, 

15 PostgreDialect, 

16 SQLDialect, 

17 SQLiteDialect, 

18) 

19from pydal.migrator import Migrator 

20from pydal.objects import Table 

21 

22from .helpers import TempdirOrExistingDir, get_typing_args 

23from .types import ( 

24 SUPPORTED_DATABASE_TYPES, 

25 SUPPORTED_DATABASE_TYPES_WITH_ALIASES, 

26 CustomAdapter, 

27 DummyDAL, 

28 SQLAdapter, 

29 UniversalSet, 

30) 

31 

32 

def sql_not_null(self: SQLDialect, default: Any, field_type: Any) -> str:
    """
    Build the NOT NULL clause for a field, with a DEFAULT only when that default is static.

    A callable default (uuid.uuid4, datetime.now, a lambda, ...) is computed at runtime and
    must not be frozen into a migration statement, so in that case only "NOT NULL" is emitted.
    Any other default is rendered through the dialect's adapter and appended as DEFAULT.

    Args:
        self (SQLDialect): The dialect whose adapter renders the default value as SQL.
        default (Any): The field's default; either a static value or a callable.
        field_type (Any): The field's type, forwarded to the adapter's `represent`.

    Returns:
        str: "NOT NULL" for callable defaults, otherwise "NOT NULL DEFAULT <rendered default>".
    """
    # dynamic default -> nothing sensible to hardcode in the SQL
    if callable(default):
        return "NOT NULL"

    # static default -> let the adapter turn it into its SQL representation
    rendered_default = self.adapter.represent(default, field_type)
    return "NOT NULL DEFAULT %s" % rendered_default

58 

59 

def _modify_migrator(self: Migrator) -> Migrator:
    """
    Replace the base SQL NOT NULL logic of a Migrator's dialect with `sql_not_null`.

    The dialect is only patched when its `not_null` is still the implementation from
    `SQLDialect`; a dialect that brings its own logic keeps it. Calling this function again
    on an already patched migrator is a no-op.

    Args:
        self (Migrator): The Migrator object to modify (its adapter's dialect is patched in place).

    Returns:
        Migrator: The same Migrator object that was passed in.
    """
    dialect = self.adapter.dialect

    # A bound method exposes the plain function underneath it as __func__.
    # Other callables (e.g. the functools.partial installed below by an earlier call)
    # have no __func__, so fall back to None instead of raising AttributeError.
    underlying = getattr(dialect.not_null, "__func__", None)

    if underlying == SQLDialect.not_null:
        # only modify base SQL notnull; bind the 'self' parameter up front
        # and monkey patch the default logic on this dialect instance:
        dialect.not_null = functools.partial(sql_not_null, dialect)

    return self

83 

84 

def _build_dummy_migrator(_driver_name: SUPPORTED_DATABASE_TYPES_WITH_ALIASES, /, db_folder: str) -> Migrator:
    """
    Build a Migrator on top of a dummy adapter that uses the SQL dialect belonging to _driver_name.

    Args:
        _driver_name: A database type or one of its aliases (matched case-insensitively),
            e.g. "postgres", "sqlite" or "mysql".
        db_folder: Folder passed to the dummy DAL as its `folder`.

    Returns:
        Migrator: A migrator for the dummy adapter, passed through `_modify_migrator`.

    Raises:
        ValueError: If the resolved driver is not a supported database type,
            or if no installed driver is available for it.
    """
    dummy_db = DummyDAL(None, migrate=False, folder=db_folder)

    # user-facing database names -> driver name used as lookup key below
    driver_aliases = {
        "postgresql": "psycopg2",
        "postgres": "psycopg2",
        "psql": "psycopg2",
        "sqlite": "sqlite3",
        "sqlite:memory": "sqlite3",
        "mysql": "pymysql",
    }

    lowered_name = _driver_name.lower()
    driver_name = driver_aliases.get(lowered_name, lowered_name)

    if driver_name not in get_typing_args(SUPPORTED_DATABASE_TYPES):
        message = (
            f"Unsupported database type {driver_name}. "
            f"Choose one of {get_typing_args(SUPPORTED_DATABASE_TYPES_WITH_ALIASES)}"
        )
        raise ValueError(message)

    # per driver: (pydal adapter class, pydal dialect class)
    backends: dict[str, tuple[typing.Type[SQLAdapter], typing.Type[Dialect]]] = {
        "psycopg2": (Postgre, PostgreDialect),
        "sqlite3": (SQLite, SQLiteDialect),
        "pymysql": (MySQL, MySQLDialect),
    }
    adapter_cls, dialect_cls = backends[driver_name]

    available_driver = dummy_db._drivers_available.get(driver_name)
    if not available_driver:  # pragma: no cover
        raise ValueError(f"Please install the correct driver for database type {driver_name}")

    class DummyAdapter(CustomAdapter):
        # copy the engine specifics from the real adapter class
        driver = available_driver
        dbengine = adapter_cls.dbengine
        _types = adapter_cls.types

        commit_on_alter_table = True

        @property
        def types(self):
            # special type that ensures 'x in types' is always true
            return UniversalSet(self._types)

    dummy_adapter = DummyAdapter(dummy_db, "", adapter_args={"driver": available_driver})

    # wire adapter, dialect and dummy db together
    dummy_adapter.dialect = dialect_cls(dummy_adapter)
    dummy_db._adapter = dummy_adapter

    return _modify_migrator(Migrator(dummy_adapter))

150 

151 

def generate_create_statement(
    define_table: Table, db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None, *, db_folder: str = None
) -> str:
    """
    Generate the `CREATE TABLE` SQL for a Table object in the dialect of the given db type.

    `define_table` is the result of `db.define_table('mytable')` (or simply `db.mytable`).
    When `db_type` (e.g. postgres, sqlite, mysql) is omitted, it is guessed from the database
    the table belongs to. The requested db_type may differ from the database actually in use;
    even a dummy database works: `db = pydal.DAL(None, migrate=False)`.

    Args:
        define_table (Table): The table to generate SQL for.
        db_type (optional): Target database type; guessed from the table when falsy.
        db_folder (str, optional): Database folder where migration (`.table`) files are stored.
            By default, a random temporary dir is created.

    Returns:
        str: The `CREATE TABLE` statement(s) as returned by the migrator.

    Raises:
        ValueError: If no db_type is given and it can not be guessed from the table.
    """
    # explicit db_type wins; otherwise fall back to the table's own database
    dialect_name = db_type or getattr(define_table._db, "_dbname", None)

    if dialect_name is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    with TempdirOrExistingDir(db_folder) as workdir:
        dummy_migrator = _build_dummy_migrator(dialect_name, db_folder=workdir)
        statement: str = dummy_migrator.create_table(define_table, migrate=False, fake_migrate=True)

    return statement

183 

184 

def sql_fields_through_tablefile(
    define_table: Table,
    db_folder: typing.Optional[str | Path] = None,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
) -> dict[str, Any]:
    """
    Collect the SQL fields of a `Table` by simulating a migration and reading back its table file.

    Args:
        define_table (Table): The table whose SQL fields should be generated.
        db_folder (str or Path, optional): Database folder to work in. When omitted,
            a temporary directory is used for the operation.
        db_type (optional): Database type (e.g. "postgres", "mysql"). When falsy, it is
            guessed from the `define_table` object.

    Returns:
        dict[str, Any]: The unpickled contents of the table file; keys are field names,
            values hold additional field information.

    Raises:
        ValueError: If `db_type` is not provided and it can not be guessed from `define_table`.
    """
    # explicit db_type wins; otherwise fall back to the table's own database
    dialect_name = db_type or getattr(define_table._db, "_dbname", None)

    if dialect_name is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    with TempdirOrExistingDir(db_folder) as workdir:
        dummy_migrator = _build_dummy_migrator(dialect_name, db_folder=workdir)

        # fake-migrate the table so that its table file is available in workdir afterwards
        dummy_migrator.create_table(define_table, migrate=True, fake_migrate=True)

        # must be read while still inside the (possibly temporary) directory
        table_file = Path(workdir) / define_table._dbt
        with table_file.open("rb") as handle:
            pickled_fields = pickle.load(handle)  # nosec B301

    return typing.cast(dict[str, Any], pickled_fields)

227 

228 

def generate_alter_statement(
    define_table_old: Table,
    define_table_new: Table,
    /,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
    *,
    db_folder: str = None,
) -> str:
    """
    Generate the SQL needed to migrate `define_table_old` into `define_table_new`.

    Args:
        define_table_old (Table): The old version of the table.
        define_table_new (Table): The new version of the table.
        db_type (optional): Database type (e.g. "postgres", "mysql"). When falsy, it is guessed
            from the `_db` attribute of the old table first and of the new table second.
        db_folder (str, optional): Database folder to work in. When omitted,
            a temporary directory is used for the operation.

    Returns:
        str: The ALTER/UPDATE lines that were logged for the migration,
            or an empty string when nothing was logged (no changes).

    Raises:
        ValueError: If `db_type` is not provided and it can not be guessed from either table.
    """
    dialect_name = (
        db_type or getattr(define_table_old._db, "_dbname", None) or getattr(define_table_new._db, "_dbname", None)
    )

    if dialect_name is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    statements = ""

    # the old fields deliberately live in another (temporary) folder than the new ones!
    old_fields = sql_fields_through_tablefile(define_table_old, db_type=dialect_name, db_folder=None)

    with TempdirOrExistingDir(db_folder) as workdir:
        workdir_path = Path(workdir)
        new_fields = sql_fields_through_tablefile(define_table_new, db_type=dialect_name, db_folder=workdir)

        dummy_migrator = _build_dummy_migrator(dialect_name, db_folder=workdir)

        # start from a clean log, so only statements of this migration end up in it
        log_file = workdir_path / "sql.log"
        log_file.unlink(missing_ok=True)

        # the tables are temporarily attached to the dummy db; remember where they came from
        saved_old_db = define_table_old._db
        saved_new_db = define_table_new._db
        try:
            define_table_old._db = dummy_migrator.db
            define_table_new._db = dummy_migrator.db

            dummy_migrator.migrate_table(
                define_table_new,
                new_fields,
                old_fields,
                new_fields,
                str(workdir_path / "<deprecated>"),
                fake_migrate=True,
            )

            # a missing log file means there were no changes, so the result stays empty
            if log_file.exists():
                with log_file.open() as log:
                    statements = "".join(entry for entry in log if entry.startswith(("ALTER", "UPDATE")))
        finally:
            # always hand the tables back to their original databases
            define_table_new._db = saved_new_db
            define_table_old._db = saved_old_db

    return statements

307 

308 

def generate_sql(
    define_table: Table,
    define_table_new: typing.Optional[Table] = None,
    /,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
    *,
    db_folder: str = None,
) -> str:
    """
    Generate SQL for one `Table` (CREATE) or for the difference between two `Table`s (ALTER).

    With only `define_table`, the work is delegated to `generate_create_statement`.
    With a truthy `define_table_new` as well, it is delegated to `generate_alter_statement`,
    treating `define_table` as the old version of the table.

    Args:
        define_table (Table): The table to generate SQL for (or the old version of it).
        define_table_new (Table, optional): The new version of the table. Defaults to None.
        db_type (optional): Database type (e.g. "postgres", "mysql"); passed on unchanged.
        db_folder (str, optional): Database folder to work in; passed on unchanged.

    Returns:
        str: A string containing the generated SQL statements.

    Raises:
        ValueError: Raised by the delegated function when no `db_type` is given
            and it can not be guessed from the table(s).
    """
    # no second table -> plain CREATE
    if not define_table_new:
        return generate_create_statement(define_table, db_type=db_type, db_folder=db_folder)

    # two tables -> ALTER from the first to the second
    return generate_alter_statement(define_table, define_table_new, db_type=db_type, db_folder=db_folder)