Coverage for src/pydal2sql/core.py: 100%
82 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-31 13:34 +0200
1"""
2Main functionality.
3"""
4import pickle # nosec: B403
5import typing
6from pathlib import Path
8import pydal
9from pydal.adapters import MySQL, Postgre, SQLAdapter, SQLite
10from pydal.migrator import Migrator
11from pydal.objects import Table
13from .helpers import TempdirOrExistingDir, get_typing_args
14from .types import SUPPORTED_DATABASE_TYPES, SUPPORTED_DATABASE_TYPES_WITH_ALIASES
class DummyDAL(pydal.DAL):  # type: ignore
    """
    DAL variant that never persists anything: committing is a deliberate no-op.
    """

    def commit(self) -> None:
        """
        Intentionally do nothing, so no real database transaction is finalized.
        """
def _build_dummy_migrator(_driver_name: SUPPORTED_DATABASE_TYPES_WITH_ALIASES, /, db_folder: str) -> Migrator:
    """
    Create a Migrator specific to the sql dialect of _driver_name.

    The Migrator is bound to a throw-away DummyDAL (connected to nothing), so it
    can emit dialect-correct SQL without ever opening a real database connection.

    Raises:
        ValueError: when the dialect is not supported or its driver is missing.
    """
    dummy_db = DummyDAL(None, migrate=False, folder=db_folder)

    # User-facing dialect names mapped onto canonical driver module names.
    alias_to_driver = {
        "postgresql": "psycopg2",
        "postgres": "psycopg2",
        "psql": "psycopg2",
        "sqlite": "sqlite3",
        "mysql": "pymysql",
    }

    normalized = _driver_name.lower()
    normalized = alias_to_driver.get(normalized, normalized)

    if normalized not in get_typing_args(SUPPORTED_DATABASE_TYPES):
        raise ValueError(
            f"Unsupported database type {normalized}. "
            f"Choose one of {get_typing_args(SUPPORTED_DATABASE_TYPES_WITH_ALIASES)}"
        )

    driver_to_adapter: dict[str, typing.Type[SQLAdapter]] = {
        "psycopg2": Postgre,
        "sqlite3": SQLite,
        "pymysql": MySQL,
    }
    adapter_cls = driver_to_adapter[normalized]

    installed_driver = dummy_db._drivers_available.get(normalized)

    if not installed_driver:  # pragma: no cover
        raise ValueError(f"Please install the correct driver for database type {normalized}")

    class DummyAdaptor(SQLAdapter):  # type: ignore
        # Borrow the dialect-specific pieces from the real adapter class, but
        # drive them through an adapter that never talks to a live server.
        types = adapter_cls.types
        driver = installed_driver
        dbengine = adapter_cls.dbengine

        commit_on_alter_table = True

    dummy_adapter = DummyAdaptor(dummy_db, "", adapter_args={"driver": installed_driver})
    dummy_db._adapter = dummy_adapter
    return Migrator(dummy_adapter)
def generate_create_statement(
    define_table: Table,
    db_type: typing.Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    *,
    db_folder: typing.Optional[str] = None,
) -> str:
    """
    Given a Table object (result of `db.define_table('mytable')` or simply db.mytable) \
    and a db type (e.g. postgres, sqlite, mysql), generate the `CREATE TABLE` SQL for that dialect.

    If no db_type is supplied, the type is guessed from the specified table.
    However, your db_type can differ from the current database used.
    You can even use a dummy database to generate SQL code with:
    `db = pydal.DAL(None, migrate=False)`

    db_folder is the database folder where migration (`.table`) files are stored.
    By default, a random temporary dir is created.

    Raises:
        ValueError: if no db_type is given and none can be guessed from the table.
    """
    if not db_type:
        # Fall back to the dialect of whatever database the table is bound to.
        db_type = getattr(define_table._db, "_dbname", None)

    if db_type is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    # Use a distinct name for the resolved folder instead of shadowing db_folder.
    with TempdirOrExistingDir(db_folder) as actual_folder:
        migrator = _build_dummy_migrator(db_type, db_folder=actual_folder)

        sql: str = migrator.create_table(
            define_table,
            migrate=True,  # write the `.table` migration file ...
            fake_migrate=True,  # ... but never execute SQL against a real database
        )
        return sql
def sql_fields_through_tablefile(
    define_table: Table,
    db_folder: typing.Optional[str | Path] = None,
    db_type: SUPPORTED_DATABASE_TYPES_WITH_ALIASES = None,
) -> dict[str, typing.Any]:
    """
    Fake-migrate define_table and read the resulting pydal `.table` file back in.

    The unpickled contents describe the table's fields as pydal's Migrator
    recorded them, which is what migrate_table expects as old/new field input.

    Raises:
        ValueError: if no db_type is given and none can be guessed from the table.
    """
    if not db_type:
        db_type = getattr(define_table._db, "_dbname", None)

    if db_type is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    with TempdirOrExistingDir(db_folder) as folder:
        migrator = _build_dummy_migrator(db_type, db_folder=folder)

        migrator.create_table(define_table, migrate=True, fake_migrate=True)

        table_file = Path(folder) / define_table._dbt
        with table_file.open("rb") as handle:
            contents = pickle.load(handle)  # nosec B301

    return typing.cast(dict[str, typing.Any], contents)
def generate_alter_statement(
    define_table_old: Table,
    define_table_new: Table,
    /,
    db_type: typing.Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    *,
    db_folder: typing.Optional[str] = None,
) -> str:
    """
    Given two versions of a Table, generate the SQL that migrates the old
    definition to the new one for the given sql dialect.

    If db_type is omitted, it is guessed from either table's bound database
    (old first, then new). db_folder is where the new table's migration
    (`.table`) files are written; a temporary dir is used by default.

    Returns an empty string when the two definitions produce no schema changes.

    Raises:
        ValueError: if no db_type is given and none can be guessed.
    """
    if not db_type:
        # Prefer the old table's dialect; fall back to the new table's.
        db_type = getattr(define_table_old._db, "_dbname", None) or getattr(define_table_new._db, "_dbname", None)

    if db_type is None:
        raise ValueError("Database dialect could not be guessed from code; Please manually define a database type!")

    result = ""

    # other db_folder than new! (db_folder=None forces a fresh temp dir, so the
    # old table's `.table` file cannot clash with the new one's below)
    old_fields = sql_fields_through_tablefile(define_table_old, db_type=db_type, db_folder=None)

    with TempdirOrExistingDir(db_folder) as db_folder:
        db_folder_path = Path(db_folder)
        new_fields = sql_fields_through_tablefile(define_table_new, db_type=db_type, db_folder=db_folder)

        migrator = _build_dummy_migrator(db_type, db_folder=db_folder)

        # The Migrator appends every statement it (fake-)runs to sql.log; clear
        # any leftover log so only this run's statements are parsed further down.
        sql_log = db_folder_path / "sql.log"
        sql_log.unlink(missing_ok=True)  # remove old crap

        # migrate_table works through the migrator's own (dummy) database, so
        # both tables are temporarily re-bound to it and restored in `finally`.
        original_db_old = define_table_old._db
        original_db_new = define_table_new._db
        try:
            define_table_old._db = migrator.db
            define_table_new._db = migrator.db

            migrator.migrate_table(
                define_table_new,
                new_fields,
                old_fields,
                new_fields,
                str(db_folder_path / "<deprecated>"),
                fake_migrate=True,
            )

            if not sql_log.exists():
                # no changes!
                return ""

            # Keep only the schema-changing statements from the log.
            with sql_log.open() as f:
                for line in f:
                    if not line.startswith(("ALTER", "UPDATE")):
                        continue

                    result += line
        finally:
            define_table_new._db = original_db_new
            define_table_old._db = original_db_old

    return result
def generate_sql(
    define_table: Table,
    define_table_new: typing.Optional[Table] = None,
    /,
    db_type: typing.Optional[SUPPORTED_DATABASE_TYPES_WITH_ALIASES] = None,
    *,
    db_folder: typing.Optional[str] = None,
) -> str:
    """
    Generate CREATE or ALTER SQL depending on how many tables are supplied.

    With only define_table, the `CREATE TABLE` statement for it is returned;
    with define_table_new as well, the `ALTER` statements migrating the first
    definition into the second are returned instead.

    db_type (e.g. postgres, sqlite, mysql) and db_folder are forwarded to the
    underlying generators; see their docstrings for the guessing/default rules.
    """
    if define_table_new:
        return generate_alter_statement(define_table, define_table_new, db_type=db_type, db_folder=db_folder)
    return generate_create_statement(define_table, db_type=db_type, db_folder=db_folder)