from sqlalchemy import cast
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import Index
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema as sql_schema
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import types as sqltypes
from sqlalchemy.events import SchemaEventTarget
from sqlalchemy.util import OrderedDict
from sqlalchemy.util import topological

from ..util import exc
from ..util.sqla_compat import _columns_for_constraint
from ..util.sqla_compat import _fk_is_self_referential
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import _remove_column_from_collection


class BatchOperationsImpl(object):
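    """Collect operations for a batch context and apply them on flush.

    Operations are accumulated in ``self.batch``; ``flush()`` either
    emits them directly as individual ALTER operations or hands them to
    :class:`ApplyBatchImpl` for a full "move and copy" table recreate.
    """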

    def __init__(
        self,
        operations,
        table_name,
        schema,
        recreate,
        copy_from,
        table_args,
        table_kwargs,
        reflect_args,
        reflect_kwargs,
        naming_convention,
        partial_reordering,
    ):
        self.operations = operations
        self.table_name = table_name
        self.schema = schema
        if recreate not in ("auto", "always", "never"):
            raise ValueError(
                "recreate may be one of 'auto', 'always', or 'never'."
            )
        self.recreate = recreate
        self.copy_from = copy_from
        self.table_args = table_args
        self.table_kwargs = dict(table_kwargs)
        self.reflect_args = reflect_args
        self.reflect_kwargs = dict(reflect_kwargs)
        self.reflect_kwargs.setdefault(
            "listeners", list(self.reflect_kwargs.get("listeners", ()))
        )
        self.reflect_kwargs["listeners"].append(
            ("column_reflect", operations.impl.autogen_column_reflect)
        )
        self.naming_convention = naming_convention
        self.partial_reordering = partial_reordering
        self.batch = []

    @property
    def dialect(self):
        return self.operations.impl.dialect

    @property
    def impl(self):
        return self.operations.impl

    def _should_recreate(self):
        if self.recreate == "auto":
            return self.operations.impl.requires_recreate_in_batch(self)
        elif self.recreate == "always":
            return True
        else:
            return False

    def flush(self):
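        """Apply all queued operations.

        If no table recreate is required, each queued operation is
        invoked directly against the migration implementation.
        Otherwise the existing table is taken from ``copy_from`` or
        reflected from the database, and :class:`ApplyBatchImpl`
        performs the recreate.
        """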

        should_recreate = self._should_recreate()

        if not should_recreate:
            for opname, arg, kw in self.batch:
                fn = getattr(self.operations.impl, opname)
                fn(*arg, **kw)
        else:
            if self.naming_convention:
                m1 = MetaData(naming_convention=self.naming_convention)
            else:
                m1 = MetaData()

            if self.copy_from is not None:
                existing_table = self.copy_from
                reflected = False
            else:
                existing_table = Table(
                    self.table_name,
                    m1,
                    schema=self.schema,
                    autoload=True,
                    autoload_with=self.operations.get_bind(),
                    *self.reflect_args,
                    **self.reflect_kwargs
                )
                reflected = True

            batch_impl = ApplyBatchImpl(
                existing_table,
                self.table_args,
                self.table_kwargs,
                reflected,
                partial_reordering=self.partial_reordering,
            )
            for opname, arg, kw in self.batch:
                fn = getattr(batch_impl, opname)
                fn(*arg, **kw)

            batch_impl._create(self.impl)

    def alter_column(self, *arg, **kw):
        self.batch.append(("alter_column", arg, kw))

    def add_column(self, *arg, **kw):
        if (
            "insert_before" in kw or "insert_after" in kw
        ) and not self._should_recreate():
            raise exc.CommandError(
                "Can't specify insert_before or insert_after when using "
                "ALTER; please specify recreate='always'"
            )
        self.batch.append(("add_column", arg, kw))

    def drop_column(self, *arg, **kw):
        self.batch.append(("drop_column", arg, kw))

    def add_constraint(self, const):
        self.batch.append(("add_constraint", (const,), {}))

    def drop_constraint(self, const):
        self.batch.append(("drop_constraint", (const,), {}))

    def rename_table(self, *arg, **kw):
        self.batch.append(("rename_table", arg, kw))

    def create_index(self, idx):
        self.batch.append(("create_index", (idx,), {}))

    def drop_index(self, idx):
        self.batch.append(("drop_index", (idx,), {}))

    def create_table(self, table):
        raise NotImplementedError("Can't create table in batch mode")

    def drop_table(self, table):
        raise NotImplementedError("Can't drop table in batch mode")


class ApplyBatchImpl(object):
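    """Perform a "move and copy" table recreate for a batch of operations.

    Copies of the existing table's columns, constraints and indexes are
    collected up front; the queued batch operations then mutate those
    copies, and ``_create()`` builds the new table, transfers the data,
    drops the original, and renames the new table into place.
    """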

    def __init__(
        self, table, table_args, table_kwargs, reflected, partial_reordering=()
    ):
        self.table = table  # this is a Table object
        self.table_args = table_args
        self.table_kwargs = table_kwargs
        self.temp_table_name = self._calc_temp_name(table.name)
        self.new_table = None

        self.partial_reordering = partial_reordering  # tuple of tuples
        self.add_col_ordering = ()  # tuple of tuples

        self.column_transfers = OrderedDict(
            (c.name, {"expr": c}) for c in self.table.c
        )
        self.existing_ordering = list(self.column_transfers)

        self.reflected = reflected
        self._grab_table_elements()

    @classmethod
    def _calc_temp_name(cls, tablename):
        return ("_alembic_tmp_%s" % tablename)[0:50]

    def _grab_table_elements(self):
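        """Collect copies of the existing table's columns along with its
        constraints, indexes and table kwargs, for later transfer to the
        new table."""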

        schema = self.table.schema
        self.columns = OrderedDict()
        for c in self.table.c:
            c_copy = c.copy(schema=schema)
            c_copy.unique = c_copy.index = False
            # ensure that the type object was copied,
            # as we may need to modify it in-place
            if isinstance(c.type, SchemaEventTarget):
                assert c_copy.type is not c.type
            self.columns[c.name] = c_copy
        self.named_constraints = {}
        self.unnamed_constraints = []
        self.indexes = {}
        self.new_indexes = {}
        for const in self.table.constraints:
            if _is_type_bound(const):
                continue
            elif self.reflected and isinstance(const, CheckConstraint):
                # TODO: we are skipping reflected CheckConstraint because
                # we have no way to determine _is_type_bound() for these.
                pass
            elif const.name:
                self.named_constraints[const.name] = const
            else:
                self.unnamed_constraints.append(const)

        for idx in self.table.indexes:
            self.indexes[idx.name] = idx

        for k in self.table.kwargs:
            self.table_kwargs.setdefault(k, self.table.kwargs[k])

    def _adjust_self_columns_for_partial_reordering(self):
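        """Rebuild self.columns and self.column_transfers in the order
        implied by partial_reordering and add_col_ordering, using a
        topological sort over (earlier, later) column-name pairs."""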

        pairs = set()

        col_by_idx = list(self.columns)

        if self.partial_reordering:
            for tuple_ in self.partial_reordering:
                for index, elem in enumerate(tuple_):
                    if index > 0:
                        pairs.add((tuple_[index - 1], elem))
        else:
            for index, elem in enumerate(self.existing_ordering):
                if index > 0:
                    pairs.add((col_by_idx[index - 1], elem))

        pairs.update(self.add_col_ordering)

        # this can happen if some columns were dropped and not removed
        # from existing_ordering.  this should be prevented already, but
        # conservatively making sure this didn't happen
        pairs = [p for p in pairs if p[0] != p[1]]

        sorted_ = list(
            topological.sort(pairs, col_by_idx, deterministic_order=True)
        )
        self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
        self.column_transfers = OrderedDict(
            (k, self.column_transfers[k]) for k in sorted_
        )

    def _transfer_elements_to_new_table(self):
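        """Build the temporary Table object, copying over the (possibly
        reordered) columns, table args/kwargs, and any constraints whose
        columns all still exist after the batched operations."""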

        assert self.new_table is None, "Can only create new table once"

        m = MetaData()
        schema = self.table.schema

        if self.partial_reordering or self.add_col_ordering:
            self._adjust_self_columns_for_partial_reordering()

        self.new_table = new_table = Table(
            self.temp_table_name,
            m,
            *(list(self.columns.values()) + list(self.table_args)),
            schema=schema,
            **self.table_kwargs
        )

        for const in (
            list(self.named_constraints.values()) + self.unnamed_constraints
        ):

            const_columns = set(
                [c.key for c in _columns_for_constraint(const)]
            )

            if not const_columns.issubset(self.column_transfers):
                continue

            if isinstance(const, ForeignKeyConstraint):
                if _fk_is_self_referential(const):
                    # for a self-referential constraint, refer to the
                    # *original* table name, and not _alembic_batch_temp.
                    # This is consistent with how we handle FK constraints
                    # from other tables; we assume SQLite (with foreign
                    # keys not enforced) just keeps the names unchanged,
                    # so when we rename back, they match again.
                    const_copy = const.copy(
                        schema=schema, target_table=self.table
                    )
                else:
                    # "target_table" for ForeignKeyConstraint.copy() is
                    # only used if the FK is detected as being
                    # self-referential, which we are handling above.
                    const_copy = const.copy(schema=schema)
            else:
                const_copy = const.copy(schema=schema, target_table=new_table)
            if isinstance(const, ForeignKeyConstraint):
                self._setup_referent(m, const)
            new_table.append_constraint(const_copy)

    def _gather_indexes_from_both_tables(self):
        idx = []
        idx.extend(self.indexes.values())
        for index in self.new_indexes.values():
            idx.append(
                Index(
                    index.name,
                    unique=index.unique,
                    *[self.new_table.c[col] for col in index.columns.keys()],
                    **index.kwargs
                )
            )
        return idx

    def _setup_referent(self, metadata, constraint):
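        """Ensure the table referenced by a ForeignKeyConstraint is
        present in the temporary MetaData, adding a stub Table with
        NULLTYPE columns if it has not been seen yet, so the FK copy can
        resolve its target."""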

        spec = constraint.elements[0]._get_colspec()
        parts = spec.split(".")
        tname = parts[-2]
        if len(parts) == 3:
            referent_schema = parts[0]
        else:
            referent_schema = None

        if tname != self.temp_table_name:
            key = sql_schema._get_table_key(tname, referent_schema)
            if key in metadata.tables:
                t = metadata.tables[key]
                for elem in constraint.elements:
                    colname = elem._get_colspec().split(".")[-1]
                    if not t.c.contains_column(colname):
                        t.append_column(Column(colname, sqltypes.NULLTYPE))
            else:
                Table(
                    tname,
                    metadata,
                    *[
                        Column(n, sqltypes.NULLTYPE)
                        for n in [
                            elem._get_colspec().split(".")[-1]
                            for elem in constraint.elements
                        ]
                    ],
                    schema=referent_schema
                )

    def _create(self, op_impl):
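        """Run the "move and copy" sequence: create the temp table,
        INSERT..SELECT the data across, drop the original table, rename
        the temp table to the original name, and re-create indexes.  On
        failure before the rename, the temp table is dropped."""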

        self._transfer_elements_to_new_table()

        op_impl.prep_table_for_batch(self.table)
        op_impl.create_table(self.new_table)

        try:
            op_impl._exec(
                self.new_table.insert(inline=True).from_select(
                    list(
                        k
                        for k, transfer in self.column_transfers.items()
                        if "expr" in transfer
                    ),
                    select(
                        [
                            transfer["expr"]
                            for transfer in self.column_transfers.values()
                            if "expr" in transfer
                        ]
                    ),
                )
            )
            op_impl.drop_table(self.table)
        except:
            op_impl.drop_table(self.new_table)
            raise
        else:
            op_impl.rename_table(
                self.temp_table_name, self.table.name, schema=self.table.schema
            )
            self.new_table.name = self.table.name
            try:
                for idx in self._gather_indexes_from_both_tables():
                    op_impl.create_index(idx)
            finally:
                self.new_table.name = self.temp_table_name

    def alter_column(
        self,
        table_name,
        column_name,
        nullable=None,
        server_default=False,
        name=None,
        type_=None,
        autoincrement=None,
        **kw
    ):
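        """Apply an alter_column operation to the pending column copy:
        rename, change the type (casting the transfer expression when
        the type affinity changes), and adjust nullability, server
        default and autoincrement."""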

        existing = self.columns[column_name]
        existing_transfer = self.column_transfers[column_name]
        if name is not None and name != column_name:
            # note that we don't change '.key' - we keep referring
            # to the renamed column by its old key in _create().  neat!
            existing.name = name
            existing_transfer["name"] = name

        if type_ is not None:
            type_ = sqltypes.to_instance(type_)
            # old type is being discarded so turn off eventing
            # rules.  Alternatively we could erase the events set up
            # by this type, but this is simpler.
            # we also ignore the drop_constraint that will come here from
            # Operations.implementation_for(alter_column)
            if isinstance(existing.type, SchemaEventTarget):
                existing.type._create_events = (
                    existing.type.create_constraint
                ) = False

            if existing.type._type_affinity is not type_._type_affinity:
                existing_transfer["expr"] = cast(
                    existing_transfer["expr"], type_
                )

            existing.type = type_

            # we *don't* however set events for the new type, because
            # alter_column is invoked from
            # Operations.implementation_for(alter_column) which already
            # will emit an add_constraint()

        if nullable is not None:
            existing.nullable = nullable
        if server_default is not False:
            if server_default is None:
                existing.server_default = None
            else:
                sql_schema.DefaultClause(server_default)._set_parent(existing)
        if autoincrement is not None:
            existing.autoincrement = bool(autoincrement)

    def _setup_dependencies_for_add_column(
        self, colname, insert_before, insert_after
    ):
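        """Record (before, after) ordering pairs in add_col_ordering so a
        newly added column lands at the requested position; without a
        position, and when partial_reordering is not in use, the column
        is ordered after the last existing column."""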

        index_cols = self.existing_ordering
        col_indexes = {name: i for i, name in enumerate(index_cols)}

        if not self.partial_reordering:
            if insert_after:
                if not insert_before:
                    if insert_after in col_indexes:
                        # insert after an existing column
                        idx = col_indexes[insert_after] + 1
                        if idx < len(index_cols):
                            insert_before = index_cols[idx]
                    else:
                        # insert after a column that is also new
                        insert_before = dict(self.add_col_ordering)[
                            insert_after
                        ]
            if insert_before:
                if not insert_after:
                    if insert_before in col_indexes:
                        # insert before an existing column
                        idx = col_indexes[insert_before] - 1
                        if idx >= 0:
                            insert_after = index_cols[idx]
                    else:
                        # insert before a column that is also new
                        insert_after = dict(
                            (b, a) for a, b in self.add_col_ordering
                        )[insert_before]

        if insert_before:
            self.add_col_ordering += ((colname, insert_before),)
        if insert_after:
            self.add_col_ordering += ((insert_after, colname),)

        if (
            not self.partial_reordering
            and not insert_before
            and not insert_after
            and col_indexes
        ):
            self.add_col_ordering += ((index_cols[-1], colname),)

    def add_column(
        self, table_name, column, insert_before=None, insert_after=None, **kw
    ):
        self._setup_dependencies_for_add_column(
            column.name, insert_before, insert_after
        )
        # we copy the column because operations.add_column()
        # gives us a Column that is part of a Table already.
        self.columns[column.name] = column.copy(schema=self.table.schema)
        self.column_transfers[column.name] = {}

    def drop_column(self, table_name, column, **kw):
        if column.name in self.table.primary_key.columns:
            _remove_column_from_collection(
                self.table.primary_key.columns, column
            )
        del self.columns[column.name]
        del self.column_transfers[column.name]
        self.existing_ordering.remove(column.name)

    def add_constraint(self, const):
        if not const.name:
            raise ValueError("Constraint must have a name")
        if isinstance(const, sql_schema.PrimaryKeyConstraint):
            if self.table.primary_key in self.unnamed_constraints:
                self.unnamed_constraints.remove(self.table.primary_key)

        self.named_constraints[const.name] = const

    def drop_constraint(self, const):
        if not const.name:
            raise ValueError("Constraint must have a name")
        try:
            const = self.named_constraints.pop(const.name)
        except KeyError:
            if _is_type_bound(const):
                # type-bound constraints are only included in the new
                # table via their type object in any case, so ignore the
                # drop_constraint() that comes here via the
                # Operations.implementation_for(alter_column)
                return
            raise ValueError("No such constraint: '%s'" % const.name)
        else:
            if isinstance(const, PrimaryKeyConstraint):
                for col in const.columns:
                    self.columns[col.name].primary_key = False

    def create_index(self, idx):
        self.new_indexes[idx.name] = idx

    def drop_index(self, idx):
        try:
            del self.indexes[idx.name]
        except KeyError:
            raise ValueError("No such index: '%s'" % idx.name)

    def rename_table(self, *arg, **kw):
        raise NotImplementedError("TODO")
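

# Illustrative sketch (not part of this module): BatchOperationsImpl is
# normally driven from a migration script through the public
# ``op.batch_alter_table()`` context manager, roughly like this
# (assuming the usual ``sqlalchemy`` imports in the migration file)::
#
#     def upgrade():
#         with op.batch_alter_table("account", recreate="auto") as batch_op:
#             batch_op.add_column(Column("last_login", DateTime()))
#             batch_op.alter_column("name", type_=String(100))
#             batch_op.drop_constraint("ck_account_status")
#
# On exit, the context manager calls flush() on the accumulated
# BatchOperationsImpl; on SQLite this typically takes the "move and
# copy" path implemented by ApplyBatchImpl above.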