1
2 r"""
3 ==========================
4 Schema module generation
5 ==========================
6
7 Schema module generation code.
8
9 :Copyright:
10
11 Copyright 2010 - 2016
12 Andr\xe9 Malo or his licensors, as applicable
13
14 :License:
15
16 Licensed under the Apache License, Version 2.0 (the "License");
17 you may not use this file except in compliance with the License.
18 You may obtain a copy of the License at
19
20 http://www.apache.org/licenses/LICENSE-2.0
21
22 Unless required by applicable law or agreed to in writing, software
23 distributed under the License is distributed on an "AS IS" BASIS,
24 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 See the License for the specific language governing permissions and
26 limitations under the License.
27
28 """
# Non-ASCII characters are kept out of the source as ``\xe9``-style
# escapes inside raw strings and expanded here at import time.
__author__ = r"Andr\xe9 Malo".encode('ascii').decode('unicode_escape')
__docformat__ = "restructuredtext en"
if __doc__:
    # pylint: disable = redefined-builtin
    # __doc__ is falsy (None) e.g. when docstrings are stripped (-OO).
    __doc__ = __doc__.encode('ascii').decode('unicode_escape')
34
35 import sqlalchemy as _sa
36
37 from . import _table
38 from . import _template
39
40
42 """
43 Schema container
44
45 :CVariables:
46 `_MODULE_TPL` : ``Template``
47 Template for the module
48
49 :IVariables:
50 `_dialect` : ``str``
51 Dialect name
52
53 `_tables` : `TableCollection`
54 Table collection
55
56 `_schemas` : ``dict``
57 Schema -> module mapping
58
59 `_symbols` : `Symbols`
60 Symbol table
61
62 `_dbname` : ``str`` or ``None``
63 DB identifier
64 """
65
# Template for the generated schema module. ``dump`` expands it with
# ``dbspec``, ``dialect``, ``imports`` and ``lines`` computed there; the
# remaining placeholders (``sa``, ``type``, ``meta``, ``table``,
# ``column``, ``default``) are expected to come from the symbol table.
# The generated module binds short aliases for SQLAlchemy helpers and the
# trailing ``del`` removes those aliases (and the metadata name) from its
# namespace again once the table definitions have been evaluated.
_MODULE_TPL = _template.Template('''
# -*- coding: ascii -*- pylint: skip-file
"""
==============================
SQLAlchemy schema definition
==============================

SQLAlchemy schema definition%(dbspec)s.

:Warning: DO NOT EDIT, this file is generated
"""
__docformat__ = "restructuredtext en"

import sqlalchemy as %(sa)s
from sqlalchemy.dialects import %(dialect)s as %(type)s
%(imports)s
%(meta)s = %(sa)s.MetaData()
%(table)s = %(sa)s.Table
%(column)s = %(sa)s.Column
%(default)s = %(sa)s.DefaultClause
%(lines)s
del %(sa)s, %(table)s, %(column)s, %(default)s, %(meta)s

# vim: nowrap tw=0
''')
91
def __init__(self, conn, tables, schemas, symbols, dbname=None,
             types=None):
    """
    Initialization

    Creates a metadata object bound to `conn`, builds the table
    collection from it and stores the results on the instance.

    :Parameters:
      `conn` : ``Connection`` or ``Engine``
        SQLAlchemy connection or engine

      `tables` : ``list``
        List of tables to reflect, (local name, table name) pairs

      `schemas` : ``dict``
        schema -> module mapping

      `symbols` : `Symbols`
        Symbol table

      `dbname` : ``str``
        Optional db identifier. Used for informational purposes. If
        omitted or ``None``, the information just won't be emitted.

      `types` : callable
        Extra type loader. If the type reflection fails, because
        SQLAlchemy cannot resolve it, the type loader will be called with
        the type name, (bound) metadata and the symbol table. It is
        responsible for modifying the symbols and imports *and* the
        dialect's ``ischema_names``. If omitted or ``None``, the reflector
        will always fail on unknown types.
    """
    bound = _sa.MetaData(conn)
    # The dialect name is later substituted into the generated module's
    # ``from sqlalchemy.dialects import ...`` line.
    dialect_name = bound.bind.dialect.name
    collection = _table.TableCollection.by_names(
        bound, tables, schemas, symbols, types=types,
    )

    self._dialect = dialect_name
    self._tables = collection
    self._schemas, self._symbols, self._dbname = schemas, symbols, dbname
130
def dump(self, fp):
    """
    Dump schema module to fp

    The module source is produced by expanding `_MODULE_TPL` with the
    symbol table entries plus the computed ``dbspec``, ``dialect``,
    ``imports`` and ``lines`` values. A final newline is written after
    the expanded template.

    :Parameters:
      `fp` : ``file``
        File to write to
    """
    # Import statements are templates themselves, rendered against the
    # symbol table. The trailing empty entry makes the joined block end
    # with a newline.
    imports = [item % self._symbols for item in self._symbols.imports]
    if imports:
        imports.sort()
        imports.append('')
    lines = []

    # Code contributed by the custom type definers, if any
    defines = self._symbols.types.defines
    if defines:
        defined = []
        for define in defines:
            defined.extend(define(self._dialect, self._symbols))
        if defined:
            lines.append('')
            lines.append('# Custom type definitions')
            lines.extend(defined)
            lines.append('')

    for table in self._tables:
        # Reference tables are not emitted into this module
        if table.is_reference:
            continue
        if not lines:
            lines.append('')
        # The generated module declares ``coding: ascii``, so non-ASCII
        # characters of the table name are backslash-escaped in the
        # comment. The codec error handler is called ``backslashreplace``
        # (there is no ``backslashescape`` handler; using it raised
        # ``LookupError`` for non-ASCII names).
        name = table.sa_table.name.encode('ascii', 'backslashreplace')
        if bytes is not str:  # Python 3: turn the escaped bytes into text
            name = name.decode('ascii')
        lines.append('# Table "%s"' % (name,))
        lines.append('%s = %r' % (table.varname, table))
        lines.append('')
        lines.append('')

    param = dict(((str(key), value) for key, value in self._symbols),
                 # 1-tuple so a tuple-valued dbname cannot break the %
                 dbspec=" for %s" % (self._dbname,) if self._dbname else "",
                 dialect=self._dialect,
                 imports='\n'.join(imports),
                 lines='\n'.join(lines))
    fp.write(self._MODULE_TPL.expand(**param))
    fp.write('\n')
176