Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# sql/util.py 

2# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: http://www.opensource.org/licenses/mit-license.php 

7 

8"""High level utilities which build upon other modules here. 

9 

10""" 

11 

12from collections import deque 

13from itertools import chain 

14 

15from . import operators 

16from . import visitors 

17from .annotation import _deep_annotate # noqa 

18from .annotation import _deep_deannotate # noqa 

19from .annotation import _shallow_annotate # noqa 

20from .base import _from_objects 

21from .base import ColumnSet 

22from .ddl import sort_tables # noqa 

23from .elements import _expand_cloned 

24from .elements import _find_columns # noqa 

25from .elements import _label_reference 

26from .elements import _textual_label_reference 

27from .elements import BindParameter 

28from .elements import ColumnClause 

29from .elements import ColumnElement 

30from .elements import Null 

31from .elements import UnaryExpression 

32from .schema import Column 

33from .selectable import Alias 

34from .selectable import FromClause 

35from .selectable import FromGrouping 

36from .selectable import Join 

37from .selectable import ScalarSelect 

38from .selectable import SelectBase 

39from .selectable import TableClause 

40from .. import exc 

41from .. import util 

42 

43 

# Module-level ``join_condition`` callable, produced by handing
# ``Join._join_condition`` to ``util.langhelpers.public_factory`` together
# with the dotted location ".sql.util.join_condition".
# NOTE(review): presumably public_factory uses that location string to
# document/expose the callable under this module -- confirm in
# util.langhelpers.
join_condition = util.langhelpers.public_factory(
    Join._join_condition, ".sql.util.join_condition"
)

47 

48 

def find_join_source(clauses, join_to):
    """Return the indexes of the FROM clauses that relate to ``join_to``.

    ``join_to`` is expanded into its FROM objects, and every entry of
    ``clauses`` is tested against each of those objects using
    ``is_derived_from()``.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to)

    is meant to point at ``clause1``.

    :param clauses: sequence of FROM clauses to search.
    :param join_to: selectable that is to be joined to.
    :return: list of integer positions into ``clauses``.  One entry is
     produced per matching (clause, FROM object) combination, so a position
     may appear more than once; the list is empty when nothing matches.

    """

    targets = list(_from_objects(join_to))
    return [
        position
        for position, from_clause in enumerate(clauses)
        for target in targets
        if from_clause.is_derived_from(target)
    ]

73 

74 

def find_left_clause_that_matches_given(clauses, join_from):
    """Given a list of FROM clauses and a selectable, return the indexes
    of those clauses which are derived from the selectable.

    :param clauses: sequence of FROM clauses to search.
    :param join_from: selectable whose FROM objects are matched against
     each clause.
    :return: list of integer positions into ``clauses``.

    """

    targets = list(_from_objects(join_from))

    # liberal pass: a clause qualifies as soon as it is derived from any
    # one of the targets.  this covers joins containing a table, or an
    # aliased table or select statement matching to a table, and will
    # match a table to a selectable that is adapted from that table.
    # With Query, this suits the case where a join is being made to an
    # adapted entity.
    liberal = [
        position
        for position, from_clause in enumerate(clauses)
        if any(from_clause.is_derived_from(target) for target in targets)
    ]

    if len(liberal) <= 1:
        return liberal

    # in an extremely small set of use cases, the target is represented in
    # more than one FROM clause.  run a second, stricter pass that ignores
    # adaption relationships and only accepts clauses that share a surface
    # selectable with one of the targets.
    conservative = [
        position
        for position in liberal
        if any(
            set(surface_selectables(clauses[position])).intersection(
                surface_selectables(target)
            )
            for target in targets
        )
    ]

    # fall back to the liberal answer if the strict pass found nothing
    return conservative or liberal

115 

116 

def find_left_clause_to_join_from(clauses, join_to, onclause):
    """Given a list of FROM clauses, a selectable,
    and optional ON clause, return a list of integer indexes from the
    clauses list indicating the clauses that can be joined from.

    The presence of an "onclause" indicates that at least one clause can
    definitely be joined from; if the list of clauses is of length one
    and the onclause is given, returns that index.   If the list of clauses
    is more than length one, and the onclause is given, attempts to locate
    which clauses contain the same columns.

    :param clauses: sequence of candidate FROM clauses.
    :param join_to: the selectable being joined to.
    :param onclause: optional ON clause expression, or ``None``.
    :return: a ``list`` of integer positions into ``clauses``.  Every code
     path returns a real list, including the "nothing resolved" fallback,
     so callers get the same type regardless of which branch was taken.

    """
    idx = []
    selectables = set(_from_objects(join_to))

    # if we are given more than one target clause to join
    # from, use the onclause to provide a more specific answer.
    # otherwise, don't try to limit, after all, "ON TRUE" is a valid
    # on clause
    if len(clauses) > 1 and onclause is not None:
        resolve_ambiguity = True
        cols_in_onclause = _find_columns(onclause)
    else:
        resolve_ambiguity = False
        cols_in_onclause = None

    for i, f in enumerate(clauses):
        # a clause is never tested against itself
        for s in selectables.difference([f]):
            if resolve_ambiguity:
                # the combined columns of both sides must cover every
                # column mentioned in the ON clause
                if set(f.c).union(s.c).issuperset(cols_in_onclause):
                    idx.append(i)
                    break
            elif Join._can_join(f, s) or onclause is not None:
                idx.append(i)
                break

    if len(idx) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        toremove = set(
            chain(*[_expand_cloned(f._hide_froms) for f in clauses])
        )
        idx = [i for i in idx if clauses[i] not in toremove]

    # onclause was given and none of them resolved, so assume
    # all indexes can match
    if not idx and onclause is not None:
        # materialize the range so the return type matches the other branch
        return list(range(len(clauses)))
    else:
        return idx

167 

168 

def visit_binary_product(fn, expr):
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right)

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_(
            (a + b) == q + func.sum(e + f),
            j == r
        )

    would have the traversal::

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    :param fn: callable invoked as ``fn(binary, left, right)``.
    :param expr: the expression to traverse.
    :return: ``None``; results are delivered only through ``fn``.

    """
    # comparison binaries currently being expanded, innermost first
    stack = []

    def visit(element):
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator
        ):
            stack.insert(0, element)
            for l in visit(element.left):
                for r in visit(element.right):
                    fn(stack[0], l, r)
            stack.pop(0)
            # a comparison yields nothing to its caller.  a former trailing
            # loop here called ``visit(child)`` on each child without
            # iterating the result; since ``visit`` is a generator function
            # those calls never executed any code, so the loop was dropped.
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                for e in visit(elem):
                    yield e

    # exhaust the generator so that the whole expression is walked
    list(visit(expr))
    visit = None  # remove gc cycles

230 

231 

def find_tables(
    clause,
    check_columns=False,
    include_aliases=False,
    include_joins=False,
    include_selects=False,
    include_crud=False,
):
    """Locate Table objects within the given expression.

    Elements visited as "table" are always collected.  The flags add
    further visit names to the collection:

    :param clause: the expression to traverse.
    :param check_columns: also collect the ``.table`` of each "column".
    :param include_aliases: also collect "alias" elements.
    :param include_joins: also collect "join" elements.
    :param include_selects: also collect "select" and "compound_select"
     elements.
    :param include_crud: also collect the ``.table`` of "insert", "update"
     and "delete" elements.
    :return: list of collected objects, in the order they were visited.

    """

    found = []
    collect = found.append

    handlers = {"table": collect}

    if include_selects:
        handlers["select"] = collect
        handlers["compound_select"] = collect

    if include_joins:
        handlers["join"] = collect

    if include_aliases:
        handlers["alias"] = collect

    if include_crud:

        def collect_statement_table(statement):
            found.append(statement.table)

        for visit_name in ("insert", "update", "delete"):
            handlers[visit_name] = collect_statement_table

    if check_columns:

        def collect_column_table(column):
            found.append(column.table)

        handlers["column"] = collect_column_table

    visitors.traverse(clause, {"column_collections": False}, handlers)
    return found

270 

271 

def unwrap_order_by(clause):
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST.

    The expression is walked breadth-first.  Elements that are not column
    elements, as well as unary expressions carrying an ordering modifier,
    are replaced by their children; label references are replaced by the
    element they wrap, and textual label references are dropped.

    :return: list of column expressions, each included only once, in the
     order in which they were first encountered.

    """

    seen = util.column_set()
    unwrapped = []
    pending = deque([clause])
    while pending:
        node = pending.popleft()
        if isinstance(node, ColumnElement) and not (
            isinstance(node, UnaryExpression)
            and operators.is_ordering_modifier(node.modifier)
        ):
            if isinstance(node, _label_reference):
                node = node.element
            if isinstance(node, _textual_label_reference):
                continue
            if node not in seen:
                seen.add(node)
                unwrapped.append(node)
        else:
            # not a result column itself; look inside it instead
            pending.extend(node.get_children())
    return unwrapped

296 

297 

def unwrap_label_reference(element):
    """Run a replacement traversal over ``element`` in which every
    ``_label_reference`` / ``_textual_label_reference`` is substituted by
    its ``.element``; all other elements are left for the traversal to
    handle (the replacement function returns ``None`` for them).

    """
    label_wrappers = (_label_reference, _textual_label_reference)

    def replace(elem):
        return elem.element if isinstance(elem, label_wrappers) else None

    return visitors.replacement_traverse(element, {}, replace)

304 

305 

def expand_column_list_from_order_by(collist, order_by):
    """Given the columns clause and ORDER BY of a selectable,
    return a list of column expressions that can be added to the collist
    corresponding to the ORDER BY, without repeating those already
    in the collist.

    :param collist: the column expressions already present.
    :param order_by: iterable of ORDER BY expressions; each is broken up
     with ``unwrap_order_by()``.
    :return: list of unwrapped ORDER BY column expressions that are not
     part of ``collist``.

    """
    already_present = set()
    for col in collist:
        # a column that has an "order by label element" is represented by
        # its inner element
        if col._order_by_label_element is not None:
            already_present.add(col.element)
        else:
            already_present.add(col)

    # unwrap every ORDER BY expression up front, then filter
    unwrapped = [unwrap_order_by(expr) for expr in order_by]
    return [
        col
        for group in unwrapped
        for col in group
        if col not in already_present
    ]

325 

326 

def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    Basically descends through Joins, as implemented by
    ``surface_selectables()``.

    """

    # compare with == rather than "is" so that a custom comparison
    # (the case in mind being Annotated objects) takes part
    return any(clause == elem for elem in surface_selectables(search))

341 

342 

def surface_selectables(clause):
    """Yield ``clause`` and the elements reachable from it by descending
    through the left/right sides of Joins and the inner element of
    FromGroupings.  Every element popped is yielded, including the Joins
    and FromGroupings themselves.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending.append(current.left)
            pending.append(current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)

352 

353 

def surface_selectables_only(clause):
    """Like a descent through Joins and FromGroupings, but yield only the
    TableClause and Alias objects that are encountered.  A ColumnClause is
    followed to its ``.table``.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        if isinstance(current, (TableClause, Alias)):
            yield current

        # independently of having been yielded, see whether there is
        # something further to descend into
        if isinstance(current, Join):
            pending.append(current.left)
            pending.append(current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        elif isinstance(current, ColumnClause):
            pending.append(current.table)

366 

367 

def surface_column_elements(clause, include_scalar_selects=True):
    """Traverse and yield only outer-exposed column elements, such as would
    be addressable in the WHERE clause of a SELECT if this element were
    in the columns clause.

    The walk is breadth-first over ``get_children()``.  Children that are
    FromGrouping objects are never entered; when ``include_scalar_selects``
    is False, children that are SelectBase objects are skipped as well.

    """

    if include_scalar_selects:
        skipped = (FromGrouping,)
    else:
        skipped = (FromGrouping, SelectBase)

    queue = deque([clause])
    while queue:
        current = queue.popleft()
        yield current
        queue.extend(
            child
            for child in current.get_children()
            if not isinstance(child, skipped)
        )

385 

386 

def selectables_overlap(left, right):
    """Return True if left/right have some overlapping selectable, i.e.
    if their ``surface_selectables()`` have at least one member in common.

    """

    left_surface = set(surface_selectables(left))
    shared = left_surface.intersection(surface_selectables(right))
    return len(shared) > 0

393 

394 

def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    The ``effective_value`` of every "bindparam" element is collected in
    the order in which the traversal visits them.

    E.g.::

        >>> expr = and_(
        ...    table.c.foo==5, table.c.foo==7
        ... )
        >>> bind_values(expr)
        [5, 7]
    """

    collected = []

    def collect_value(bindparam):
        collected.append(bindparam.effective_value)

    handlers = {"bindparam": collect_value}
    visitors.traverse(clause, {}, handlers)
    return collected

414 

415 

def _quote_ddl_expr(element):
    """Render ``element`` as a literal: strings are wrapped in single
    quotes with embedded single quotes doubled, anything else is rendered
    with ``repr()``.

    """
    if not isinstance(element, util.string_types):
        return repr(element)

    escaped = element.replace("'", "''")
    return "'%s'" % escaped

422 

423 

424class _repr_base(object): 

425 _LIST = 0 

426 _TUPLE = 1 

427 _DICT = 2 

428 

429 __slots__ = ("max_chars",) 

430 

431 def trunc(self, value): 

432 rep = repr(value) 

433 lenrep = len(rep) 

434 if lenrep > self.max_chars: 

435 segment_length = self.max_chars // 2 

436 rep = ( 

437 rep[0:segment_length] 

438 + ( 

439 " ... (%d characters truncated) ... " 

440 % (lenrep - self.max_chars) 

441 ) 

442 + rep[-segment_length:] 

443 ) 

444 return rep 

445 

446 

class _repr_row(_repr_base):
    """Provide a string view of a row.

    The row is rendered like a tuple, each value being passed through
    ``trunc()``; a single-element row gets a trailing comma.

    """

    __slots__ = ("row",)

    def __init__(self, row, max_chars=300):
        self.row = row
        self.max_chars = max_chars

    def __repr__(self):
        rendered = ", ".join(map(self.trunc, self.row))
        trailing_comma = "," if len(self.row) == 1 else ""
        return "(%s%s)" % (rendered, trailing_comma)

462 

463 

class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    """

    __slots__ = "params", "batches", "ismulti"

    def __init__(self, params, batches, max_chars=300, ismulti=None):
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars

    def __repr__(self):
        params = self.params

        # without a statement about single vs. multi parameters, just
        # show the truncated repr of the whole thing
        if self.ismulti is None:
            return self.trunc(params)

        if isinstance(params, list):
            typ = self._LIST
        elif isinstance(params, tuple):
            typ = self._TUPLE
        elif isinstance(params, dict):
            typ = self._DICT
        else:
            return self.trunc(params)

        if not self.ismulti:
            return self._repr_params(params, typ)

        if len(params) <= self.batches:
            return self._repr_multi(params, typ)

        # too many parameter sets: show the leading sets and the last two,
        # dropping the closing / opening bracket of the two renderings so
        # that they read as a single collection around the notice
        leading = self._repr_multi(params[: self.batches - 2], typ)[0:-1]
        notice = " ... displaying %i of %i total bound parameter sets ... " % (
            self.batches,
            len(params),
        )
        trailing = self._repr_multi(params[-2:], typ)[1:]
        return " ".join((leading, notice, trailing))

    def _repr_multi(self, multi_params, typ):
        """Render a sequence of parameter sets; the type of the first set
        decides how every set is rendered."""
        if multi_params:
            first = multi_params[0]
            if isinstance(first, list):
                elem_type = self._LIST
            elif isinstance(first, tuple):
                elem_type = self._TUPLE
            elif isinstance(first, dict):
                elem_type = self._DICT
            else:
                assert False, "Unknown parameter type %s" % (type(first))

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        template = "[%s]" if typ == self._LIST else "(%s)"
        return template % elements

    def _repr_params(self, params, typ):
        """Render one parameter set as a dict, tuple or list literal with
        each value truncated."""
        trunc = self.trunc

        if typ is self._DICT:
            body = ", ".join(
                "%r: %s" % (key, trunc(value)) for key, value in params.items()
            )
            return "{%s}" % body

        if typ is self._TUPLE:
            body = ", ".join(trunc(value) for value in params)
            return "(%s%s)" % (body, "," if len(params) == 1 else "")

        return "[%s]" % ", ".join(trunc(value) for value in params)

550 

551 

def adapt_criterion_to_null(crit, nulls):
    """Given criterion containing bind params, convert selected elements
    to IS NULL.

    :param crit: the criterion expression; it is processed with
     ``visitors.cloned_traverse`` and the result of that call is returned.
    :param nulls: collection of bind parameter identifying keys; a binary
     whose left or right side is a BindParameter with such a key is
     turned into an ``IS NULL`` comparison of the other side.

    """

    def is_nulled_bind(operand):
        return (
            isinstance(operand, BindParameter)
            and operand._identifying_key in nulls
        )

    def visit_binary(binary):
        if is_nulled_bind(binary.left):
            # reverse order if the NULL is on the left side
            binary.left = binary.right
        elif not is_nulled_bind(binary.right):
            return

        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.isnot

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})

577 

578 

def splice_joins(left, right, stop_on=None):
    """Adapt ``right`` against ``left``, descending the left side of Joins.

    If ``left`` is None, ``right`` is returned untouched.  Otherwise a
    ``ClauseAdapter`` for ``left`` is created and the chain formed by
    ``right`` and its successive ``.left`` attributes is walked: every
    Join in that chain (other than ``stop_on``) is cloned and has its
    ON clause run through the adapter, while the first element that is not
    such a Join is itself run through the adapter and ends the walk.

    :param left: selectable to adapt against, or None.
    :param right: the element, typically a Join, to be processed.
    :param stop_on: optional Join at which descent stops; it is adapted as
     a whole instead of being cloned and descended into.
    :return: the processed counterpart of ``right``.

    """
    if left is None:
        return right

    # entries are (element to process, already-cloned parent join or None)
    stack = [(right, None)]

    adapter = ClauseAdapter(left)
    ret = None
    while stack:
        (right, prevright) = stack.pop()
        if isinstance(right, Join) and right is not stop_on:
            # work on a copy so that the join passed in is not modified
            right = right._clone()
            right._reset_exported()
            right.onclause = adapter.traverse(right.onclause)
            # continue with the left side, remembering the clone as parent
            stack.append((right.left, right))
        else:
            right = adapter.traverse(right)
        if prevright is not None:
            # attach the processed element as the parent clone's left side
            prevright.left = right
        if ret is None:
            # the first element processed is the outermost one
            ret = right

    return ret

602 

603 

def reduce_columns(columns, *clauses, **kw):
    r"""Given a list of columns, return a 'reduced' set based on natural
    equivalents.

    The set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms', in which case two columns are only
    treated as equivalent when their names are equal as well.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    columns = util.ordered_column_set(columns)
    omit = util.column_set()

    # pass 1: omit a column when one of its foreign keys refers to another
    # column of the collection
    for col in columns:
        foreign_keys = chain.from_iterable(
            [proxied.foreign_keys for proxied in col.proxy_set]
        )
        for fk in foreign_keys:
            for other in columns:
                if other is col:
                    continue
                try:
                    referenced = fk.column
                except (
                    exc.NoReferencedColumnError,
                    exc.NoReferencedTableError,
                ):
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    raise
                if referenced.shares_lineage(other) and (
                    not only_synonyms or other.name == col.name
                ):
                    omit.add(col)
                    break

    # pass 2: omit columns that the given clauses equate with another
    # column of the collection
    def visit_binary(binary):
        if binary.operator == operators.eq:
            remaining = util.column_set(
                chain.from_iterable(
                    [c.proxy_set for c in columns.difference(omit)]
                )
            )
            if binary.left in remaining and binary.right in remaining:
                for candidate in reversed(columns):
                    if candidate.shares_lineage(binary.right) and (
                        not only_synonyms
                        or candidate.name == binary.left.name
                    ):
                        omit.add(candidate)
                        break

    for clause in clauses:
        if clause is not None:
            visitors.traverse(clause, {}, {"binary": visit_binary})

    return ColumnSet(columns.difference(omit))

677 

678 

def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """Traverse an expression and locate binary criterion pairs.

    Only binaries whose two sides are both column elements are considered,
    and, unless ``any_operator`` is set, only those using the equality
    operator.

    :param expression: the expression to traverse.
    :param consider_as_foreign_keys: optional collection of columns to be
     treated as the foreign key side; such a column is placed second in
     its pair.
    :param consider_as_referenced_keys: optional collection of columns to
     be treated as the referenced side; such a column is placed first in
     its pair.
    :param any_operator: if True, do not restrict to equality comparisons.
    :return: list of two-element tuples of columns.  When neither
     collection is given, pairs are derived from ``Column.references()``
     and the referenced column comes first.
    :raises ArgumentError: if both collections are given.

    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    pairs = []

    def match_against(keys, left, right):
        # returns (member of keys, opposite side), or None if neither side
        # qualifies.  a side qualifies when it is in keys and the opposite
        # side either compares equal to it or is not itself in keys.
        if left in keys and (right.compare(left) or right not in keys):
            return left, right
        if right in keys and (left.compare(right) or left not in keys):
            return right, left
        return None

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return

        left, right = binary.left, binary.right
        if not (
            isinstance(left, ColumnElement)
            and isinstance(right, ColumnElement)
        ):
            return

        if consider_as_foreign_keys:
            matched = match_against(consider_as_foreign_keys, left, right)
            if matched is not None:
                foreign, referenced = matched
                pairs.append((referenced, foreign))
        elif consider_as_referenced_keys:
            matched = match_against(consider_as_referenced_keys, left, right)
            if matched is not None:
                referenced, foreign = matched
                pairs.append((referenced, foreign))
        elif isinstance(left, Column) and isinstance(right, Column):
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs

740 

741 

class ClauseAdapter(visitors.ReplacingCloningVisitor):
    """Clones and modifies clauses based on column correspondence.

    E.g.::

      table1 = Table('sometable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )
      table2 = Table('someothertable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )

      condition = table1.c.col1 == table2.c.col1

    make an alias of table1::

      s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

      s.c.col1 == table2.c.col1

    """

    def __init__(
        self,
        selectable,
        equivalents=None,
        include_fn=None,
        exclude_fn=None,
        adapt_on_names=False,
        anonymize_labels=False,
    ):
        self.selectable = selectable
        self.equivalents = util.column_dict(equivalents or {})
        self.include_fn = include_fn
        self.exclude_fn = exclude_fn
        self.adapt_on_names = adapt_on_names
        self.__traverse_options__ = {
            "stop_on": [selectable],
            "anonymize_labels": anonymize_labels,
        }

    def _corresponding_column(
        self, col, require_embedded, _seen=util.EMPTY_SET
    ):
        """Locate the column of ``self.selectable`` that stands in for
        ``col``, consulting ``self.equivalents`` and, if enabled, the
        column name when there is no direct correspondence.  Returns None
        if nothing is found.  ``_seen`` holds the columns already expanded
        through ``equivalents`` so that they are not expanded again."""
        found = self.selectable.corresponding_column(
            col, require_embedded=require_embedded
        )
        if found is not None:
            return found

        if col in self.equivalents and col not in _seen:
            for equiv in self.equivalents[col]:
                found = self._corresponding_column(
                    equiv,
                    require_embedded=require_embedded,
                    _seen=_seen.union([col]),
                )
                if found is not None:
                    return found

        if self.adapt_on_names:
            return self.selectable.c.get(col.name)
        return None

    def replace(self, col):
        """Return the replacement for ``col``, or None to leave it as is."""
        if isinstance(col, FromClause) and self.selectable.is_derived_from(
            col
        ):
            # a FROM element that our selectable is derived from is
            # replaced by the selectable as a whole
            return self.selectable

        if not isinstance(col, ColumnElement):
            return None

        if self.include_fn and not self.include_fn(col):
            return None

        if self.exclude_fn and self.exclude_fn(col):
            return None

        return self._corresponding_column(col, True)

820 

821 

class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Key aspects of ColumnAdapter include:

    * Expressions that are adapted are stored in a persistent
      .columns collection; so that an expression E adapted into
      an expression E1, will return the same object E1 when adapted
      a second time.   This is important in particular for things like
      Label objects that are anonymized, so that the ColumnAdapter can
      be used to present a consistent "adapted" view of things.

    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This is because "annotated" items all have the same hash identity
      as their parent.

    * "wrapping" capability is added, so that the replacement of an expression
      E can proceed through a series of adapters.  This differs from the
      visitor's "chaining" feature in that the resulting object is passed
      through all replacing functions unconditionally, rather than stopping
      at the first one that returns non-None.

    * An adapt_required option, used by eager loading to indicate that
      we don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    def __init__(
        self,
        selectable,
        equivalents=None,
        adapt_required=False,
        include_fn=None,
        exclude_fn=None,
        adapt_on_names=False,
        allow_label_resolve=True,
        anonymize_labels=False,
    ):
        ClauseAdapter.__init__(
            self,
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
        )

        # mapping that computes missing entries through _locate_col
        self.columns = util.WeakPopulateDict(self._locate_col)
        if self.include_fn or self.exclude_fn:
            # put the include/exclude check in front of the mapping
            self.columns = self._IncludeExcludeMapping(self, self.columns)
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        # the adapter that results are additionally passed through; set
        # on the copies produced by wrap()
        self._wrap = None

    class _IncludeExcludeMapping(object):
        """Lookup front-end applying the parent adapter's include_fn /
        exclude_fn before consulting the wrapped ``columns`` mapping."""

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            # a key rejected by include_fn, or selected by exclude_fn, is
            # not looked up in our own collection: it is delegated to the
            # wrapped adapter's columns if there is one, otherwise the key
            # itself is returned
            if (
                self.parent.include_fn and not self.parent.include_fn(key)
            ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
                if self.parent._wrap:
                    return self.parent._wrap.columns[key]
                else:
                    return key
            return self.columns[key]

    def wrap(self, adapter):
        """Return a copy of this adapter whose ``_wrap`` is ``adapter``.

        The copy is created without calling ``__init__``; it receives this
        adapter's instance dictionary and a new ``columns`` collection
        bound to the copy.

        """
        ac = self.__class__.__new__(self.__class__)
        ac.__dict__.update(self.__dict__)
        ac._wrap = adapter
        # the copied dict still refers to our own collection; replace it
        # with one that populates through the copy
        ac.columns = util.WeakPopulateDict(ac._locate_col)
        if ac.include_fn or ac.exclude_fn:
            ac.columns = self._IncludeExcludeMapping(ac, ac.columns)

        return ac

    def traverse(self, obj):
        """Return the adapted form of ``obj`` by looking it up in the
        ``columns`` collection."""
        return self.columns[obj]

    # additional names for the adaptation entry points
    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def _locate_col(self, col):
        """Compute the adapted form of ``col``; this is the function the
        ``columns`` collection is constructed with."""

        c = ClauseAdapter.traverse(self, col)

        if self._wrap:
            # pass the result through the wrapped adapter as well, keeping
            # our own result if that adapter returns None
            c2 = self._wrap._locate_col(c)
            if c2 is not None:
                c = c2

        if self.adapt_required and c is col:
            # the very same object came back; with adapt_required that is
            # reported as None
            return None

        c._allow_label_resolve = self.allow_label_resolve

        return c

    def __getstate__(self):
        # the "columns" collection is left out of the pickled state
        d = self.__dict__.copy()
        del d["columns"]
        return d

    def __setstate__(self, state):
        self.__dict__.update(state)
        # NOTE(review): unlike __init__ and wrap(), the collection is not
        # placed behind _IncludeExcludeMapping here when include_fn or
        # exclude_fn is set -- confirm whether that is intended.
        self.columns = util.WeakPopulateDict(self._locate_col)