Coverage for src/pydal2sql/core.py: 100%

82 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-31 13:34 +0200

1""" 

2Main functionality. 

3""" 

4import pickle # nosec: B403 

5import typing 

6from pathlib import Path 

7 

8import pydal 

9from pydal.adapters import MySQL, Postgre, SQLAdapter, SQLite 

10from pydal.migrator import Migrator 

11from pydal.objects import Table 

12 

13from .helpers import TempdirOrExistingDir, get_typing_args 

14from .types import SUPPORTED_DATABASE_TYPES, SUPPORTED_DATABASE_TYPES_WITH_ALIASES 

15 

16 

class DummyDAL(pydal.DAL):  # type: ignore
    """
    pydal.DAL subclass whose `commit` is a no-op.
    """

    def commit(self) -> None:
        """
        Intentionally does nothing.
        """
        return None

22 

23 

def _build_dummy_migrator(_driver_name: SUPPORTED_DATABASE_TYPES_WITH_ALIASES, /, db_folder: str) -> Migrator:
    """
    Build a Migrator that speaks the SQL dialect selected by _driver_name.

    Args:
        _driver_name: database type or one of its aliases (matched case-insensitively),
            e.g. "postgres", "sqlite" or "mysql".
        db_folder: folder passed on to the dummy DAL instance.

    Returns:
        A Migrator wrapping a dummy adapter that borrows `types` and `dbengine`
        from the real pydal adapter of the requested dialect.

    Raises:
        ValueError: when the (alias-resolved) database type is not supported,
            or when pydal reports no available driver for it.
    """
    dummy_db = DummyDAL(None, migrate=False, folder=db_folder)

    # user-facing aliases -> driver name
    alias_to_driver = {
        "postgresql": "psycopg2",
        "postgres": "psycopg2",
        "psql": "psycopg2",
        "sqlite": "sqlite3",
        "mysql": "pymysql",
    }

    lowered = _driver_name.lower()
    driver_name = alias_to_driver.get(lowered, lowered)

    supported_drivers = get_typing_args(SUPPORTED_DATABASE_TYPES)
    if driver_name not in supported_drivers:
        raise ValueError(
            f"Unsupported database type {driver_name}. "
            f"Choose one of {get_typing_args(SUPPORTED_DATABASE_TYPES_WITH_ALIASES)}"
        )

    # driver name -> pydal adapter class that defines the dialect
    adapters_per_database: dict[str, typing.Type[SQLAdapter]] = {
        "psycopg2": Postgre,
        "sqlite3": SQLite,
        "pymysql": MySQL,
    }
    dialect_adapter = adapters_per_database[driver_name]

    installed_driver = dummy_db._drivers_available.get(driver_name)
    if not installed_driver:  # pragma: no cover
        raise ValueError(f"Please install the correct driver for database type {driver_name}")

    class DummyAdaptor(SQLAdapter):  # type: ignore
        # dialect specifics are copied from the real adapter class
        types = dialect_adapter.types
        driver = installed_driver
        dbengine = dialect_adapter.dbengine

        commit_on_alter_table = True

    dummy_adapter = DummyAdaptor(dummy_db, "", adapter_args={"driver": installed_driver})
    dummy_db._adapter = dummy_adapter
    return Migrator(dummy_adapter)

70 

71 

def generate_create_statement(
    define_table: Table, db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None, *, db_folder: str = None
) -> str:
    """
    Produce the `CREATE TABLE` SQL for a pydal Table in the requested dialect.

    define_table is a Table object (the result of `db.define_table('mytable')`, or simply `db.mytable`).

    db_type names the dialect (e.g. postgres, sqlite, mysql). When it is omitted,
    the dialect is guessed from the database the table belongs to.
    The dialect does not have to match the database actually in use; a dummy database
    such as `db = pydal.DAL(None, migrate=False)` is enough to generate SQL.

    db_folder is the database folder where migration (`.table`) files are stored.
    By default, a random temporary dir is created.

    Raises:
        ValueError: when no db_type is given and none can be guessed from the table.
    """
    # fall back to the dialect of the table's own database
    db_type = db_type or getattr(define_table._db, "_dbname", None)

    if db_type is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    with TempdirOrExistingDir(db_folder) as folder:
        dummy_migrator = _build_dummy_migrator(db_type, db_folder=folder)
        create_sql: str = dummy_migrator.create_table(define_table, migrate=True, fake_migrate=True)

    return create_sql

102 

103 

def sql_fields_through_tablefile(
    define_table: Table,
    db_folder: typing.Optional[str | Path] = None,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
) -> dict[str, typing.Any]:
    """
    Run a fake migration for define_table and return the unpickled contents of its `.table` file.

    Args:
        define_table: the pydal Table to process.
        db_folder: folder in which the table file is written; passed to TempdirOrExistingDir.
        db_type: SQL dialect (or alias); guessed from the table's database when omitted.

    Returns:
        The object stored in the table file (`define_table._dbt`), typed as a dict.

    Raises:
        ValueError: when no db_type is given and none can be guessed from the table.
    """
    # fall back to the dialect of the table's own database
    db_type = db_type or getattr(define_table._db, "_dbname", None)

    if db_type is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    with TempdirOrExistingDir(db_folder) as folder:
        dummy_migrator = _build_dummy_migrator(db_type, db_folder=folder)
        dummy_migrator.create_table(define_table, migrate=True, fake_migrate=True)

        # the table file is read while the directory context is still open
        table_file = Path(folder) / define_table._dbt
        with table_file.open("rb") as tfile:
            loaded_tables = pickle.load(tfile)  # nosec B301

    return typing.cast(dict[str, typing.Any], loaded_tables)

128 

129 

def generate_alter_statement(
    define_table_old: Table,
    define_table_new: Table,
    /,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
    *,
    db_folder: str = None,
) -> str:
    """
    Generate the SQL needed to migrate define_table_old into define_table_new.

    Args:
        define_table_old: Table object describing the current structure.
        define_table_new: Table object describing the desired structure.
        db_type: SQL dialect (or alias); when omitted it is guessed from the database
            of the old table, then from the database of the new table.
        db_folder: folder used for the new table's migration files and `sql.log`;
            passed to TempdirOrExistingDir.

    Returns:
        The lines of the generated `sql.log` that start with `ALTER` or `UPDATE`, concatenated.
        An empty string when no `sql.log` was produced (no changes).

    Raises:
        ValueError: when no db_type is given and none can be guessed from either table.
    """
    if not db_type:
        db_type = getattr(define_table_old._db, "_dbname", None) or getattr(define_table_new._db, "_dbname", None)

    if db_type is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    # other db_folder than new!
    old_fields = sql_fields_through_tablefile(define_table_old, db_type=db_type, db_folder=None)

    with TempdirOrExistingDir(db_folder) as db_folder:
        db_folder_path = Path(db_folder)
        new_fields = sql_fields_through_tablefile(define_table_new, db_type=db_type, db_folder=db_folder)

        migrator = _build_dummy_migrator(db_type, db_folder=db_folder)

        sql_log = db_folder_path / "sql.log"
        sql_log.unlink(missing_ok=True)  # drop any earlier log so only this migration's statements are read

        # both tables are temporarily attached to the dummy database; restored in `finally`
        original_db_old = define_table_old._db
        original_db_new = define_table_new._db
        try:
            define_table_old._db = migrator.db
            define_table_new._db = migrator.db

            migrator.migrate_table(
                define_table_new,
                new_fields,
                old_fields,
                new_fields,
                str(db_folder_path / "<deprecated>"),
                fake_migrate=True,
            )

            if not sql_log.exists():
                # no changes!
                return ""

            # Read with an explicit encoding instead of the platform default.
            # NOTE(review): pydal appears to write sql.log as UTF-8 bytes - confirm against pydal.migrator.
            with sql_log.open(encoding="utf-8") as f:
                return "".join(line for line in f if line.startswith(("ALTER", "UPDATE")))
        finally:
            define_table_new._db = original_db_new
            define_table_old._db = original_db_old

188 

189 

def generate_sql(
    define_table: Table,
    define_table_new: Table = None,
    /,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
    *,
    db_folder: str = None,
) -> str:
    """
    Generate SQL for one table (CREATE) or for the difference between two tables (ALTER).

    With only define_table, this delegates to generate_create_statement.
    With a (truthy) define_table_new as well, this delegates to generate_alter_statement,
    treating define_table as the old definition.

    db_type and db_folder are forwarded unchanged to the chosen generator.
    """
    if not define_table_new:
        return generate_create_statement(define_table, db_type=db_type, db_folder=db_folder)

    return generate_alter_statement(define_table, define_table_new, db_type=db_type, db_folder=db_folder)