Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# sql/visitors.py 

2# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: http://www.opensource.org/licenses/mit-license.php 

7 

"""Visitor/traversal interface and library functions.

SQLAlchemy schema and expression constructs rely on a Python-centric
version of the classic "visitor" pattern as the primary way in which
they apply functionality.  The most common use of this pattern
is statement compilation, where individual expression classes match
up to rendering methods that produce a string result.  Beyond this,
the visitor system is also used to inspect expressions for various
information and patterns, as well as for the purposes of applying
transformations to expressions.

This module provides:

* :class:`.VisitableType` / :class:`.Visitable`, which give element classes
  a ``_compiler_dispatch`` method keyed on their ``__visit_name__``;
* the visitor base classes :class:`.ClauseVisitor`,
  :class:`.CloningVisitor` and :class:`.ReplacingCloningVisitor`;
* the functions :func:`.iterate`, :func:`.iterate_depthfirst`,
  :func:`.traverse_using`, :func:`.traverse`, :func:`.traverse_depthfirst`,
  :func:`.cloned_traverse` and :func:`.replacement_traverse`.

Examples of how the visit system is used can be seen in the source code
of, for example, the ``sqlalchemy.sql.util`` and ``sqlalchemy.sql.compiler``
modules.  Some background on clause adaptation is also at
http://techspot.zzzeek.org/2008/01/23/expression-transformations/ .

"""

25 

26from collections import deque 

27import operator 

28 

29from .. import exc 

30from .. import util 

31 

32 

# Public API of this module: the metaclass and base class that equip
# element classes with ``_compiler_dispatch``, the visitor base classes,
# and the iteration / traversal functions.
__all__ = [
    "VisitableType",
    "Visitable",
    "ClauseVisitor",
    "CloningVisitor",
    "ReplacingCloningVisitor",
    "iterate",
    "iterate_depthfirst",
    "traverse_using",
    "traverse",
    "traverse_depthfirst",
    "cloned_traverse",
    "replacement_traverse",
]

47 

48 

class VisitableType(type):
    """Metaclass that hooks visitable classes up to compiler dispatch.

    Each class created through this metaclass that has a
    ``__visit_name__`` attribute is handed to ``_generate_dispatch()``,
    which installs a ``_compiler_dispatch`` instance method behaving
    approximately like::

        def _compiler_dispatch (self, visitor, **kw):
            '''Look for an attribute named "visit_" + self.__visit_name__
            on the visitor, and call it with the same kw params.'''
            visit_attr = 'visit_%s' % self.__visit_name__
            return getattr(visitor, visit_attr)(self, **kw)

    The class literally named ``Visitable`` is skipped, as is any class
    without a ``__visit_name__`` attribute; those are left unaffected.

    """

    def __init__(cls, clsname, bases, clsdict):
        # The root "Visitable" base never gets a dispatcher of its own.
        is_root = clsname == "Visitable"
        if not is_root and hasattr(cls, "__visit_name__"):
            _generate_dispatch(cls)

        super(VisitableType, cls).__init__(clsname, bases, clsdict)

71 

72 

73def _generate_dispatch(cls): 

74 """Return an optimized visit dispatch function for the cls 

75 for use by the compiler. 

76 """ 

77 if "__visit_name__" in cls.__dict__: 

78 visit_name = cls.__visit_name__ 

79 

80 if isinstance(visit_name, util.compat.string_types): 

81 # There is an optimization opportunity here because the 

82 # the string name of the class's __visit_name__ is known at 

83 # this early stage (import time) so it can be pre-constructed. 

84 getter = operator.attrgetter("visit_%s" % visit_name) 

85 

86 def _compiler_dispatch(self, visitor, **kw): 

87 try: 

88 meth = getter(visitor) 

89 except AttributeError as err: 

90 util.raise_( 

91 exc.UnsupportedCompilationError(visitor, cls), 

92 replace_context=err, 

93 ) 

94 else: 

95 return meth(self, **kw) 

96 

97 else: 

98 # The optimization opportunity is lost for this case because the 

99 # __visit_name__ is not yet a string. As a result, the visit 

100 # string has to be recalculated with each compilation. 

101 def _compiler_dispatch(self, visitor, **kw): 

102 visit_attr = "visit_%s" % self.__visit_name__ 

103 try: 

104 meth = getattr(visitor, visit_attr) 

105 except AttributeError as err: 

106 util.raise_( 

107 exc.UnsupportedCompilationError(visitor, cls), 

108 replace_context=err, 

109 ) 

110 else: 

111 return meth(self, **kw) 

112 

113 _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__ 

114 on the visitor, and call it with the same kw params. 

115 """ 

116 cls._compiler_dispatch = _compiler_dispatch 

117 

118 

class Visitable(util.with_metaclass(VisitableType, object)):
    """Base class for visitable objects; applies the
    :class:`.visitors.VisitableType` metaclass.

    :class:`.VisitableType` skips the class named ``Visitable`` itself, so
    a ``_compiler_dispatch`` method is generated only for subclasses that
    define ``__visit_name__``.

    The :class:`.Visitable` class is essentially at the base of the
    :class:`_expression.ClauseElement` hierarchy.

    """

127 

128 

class ClauseVisitor(object):
    """Base class for visitor objects that traverse expression structures
    via :func:`.visitors.traverse`.

    Subclasses define ``visit_<name>`` methods, where ``<name>`` is the
    ``__visit_name__`` of the kind of element the method handles.  Further
    visitors can be linked after this one with :meth:`.chain`.

    Calling :func:`.visitors.traverse` directly is usually preferred.

    """

    __traverse_options__ = {}

    def traverse_single(self, obj, **kw):
        """Dispatch ``obj`` to the first visitor in the chain that has a
        ``visit_<name>`` method for it, returning that method's result.

        Returns ``None`` if no visitor in the chain handles ``obj``.
        """
        for visitor in self.visitor_iterator:
            handler = getattr(visitor, "visit_%s" % obj.__visit_name__, None)
            if not handler:
                continue
            return handler(obj, **kw)

    def iterate(self, obj):
        """Return an iterator over all elements of the expression
        structure ``obj``, using this visitor's ``__traverse_options__``.

        """
        return iterate(obj, self.__traverse_options__)

    def traverse(self, obj):
        """Traverse ``obj``, invoking this visitor's ``visit_`` methods."""

        options = self.__traverse_options__
        return traverse(obj, options, self._visitor_dict)

    @util.memoized_property
    def _visitor_dict(self):
        # Maps "<name>" -> bound method for each visit_<name> attribute.
        prefix = "visit_"
        return {
            attr[len(prefix):]: getattr(self, attr)
            for attr in dir(self)
            if attr.startswith(prefix)
        }

    @property
    def visitor_iterator(self):
        """Yield this visitor, then each visitor linked after it through
        the ``_next`` attribute set by :meth:`.chain`."""

        current = self
        while current:
            yield current
            current = getattr(current, "_next", None)

    def chain(self, visitor):
        """Append ``visitor`` to the end of this visitor's chain.

        The appended visitor receives visit events after every visitor
        already in the chain.  Returns ``self``.

        """
        linked = list(self.visitor_iterator)
        linked[-1]._next = visitor
        return self

class CloningVisitor(ClauseVisitor):
    """Base class for visitor objects that traverse expression structures
    via :func:`.visitors.cloned_traverse`.

    The ``visit_`` methods receive already-copied elements and may alter
    them while the traversal proceeds.

    Calling :func:`.visitors.cloned_traverse` directly is usually
    preferred.

    """

    def copy_and_process(self, list_):
        """Run a cloned traversal over each element of ``list_`` and return
        a new list holding the results, in the same order.

        """
        processed = []
        for element in list_:
            processed.append(self.traverse(element))
        return processed

    def traverse(self, obj):
        """Clone ``obj``, invoking this visitor's ``visit_`` methods on the
        copies, and return the cloned structure."""

        options = self.__traverse_options__
        return cloned_traverse(obj, options, self._visitor_dict)

210 

211 

class ReplacingCloningVisitor(CloningVisitor):
    """Base class for visitor objects that traverse expression structures
    via :func:`.visitors.replacement_traverse`.

    Subclasses override :meth:`.replace` to substitute elements.

    Calling :func:`.visitors.replacement_traverse` directly is usually
    preferred.

    """

    def replace(self, elem):
        """Receive each element before it is copied during a cloning
        traversal.

        Return a different element to use it in place of ``elem``; that
        element is used as-is and is not traversed, including when it is
        encountered again.  Return ``None`` (as this default does) to let
        ``elem`` be copied normally.
        """
        return None

    def traverse(self, obj):
        """Traverse ``obj``, offering each element to the ``replace()``
        method of every visitor in the chain."""

        def first_replacement(elem):
            # Ask the chained visitors in order; the first non-None wins.
            answers = (v.replace(elem) for v in self.visitor_iterator)
            for answer in answers:
                if answer is not None:
                    return answer
            return None

        options = self.__traverse_options__
        return replacement_traverse(obj, options, first_replacement)

240 

241 

def iterate(obj, opts=None):
    r"""Traverse the given expression structure, returning an iterator.

    Traversal is breadth-first: ``obj`` comes first, then its children,
    then their children, and so on.  The full sequence is computed
    eagerly before the iterator is returned, so changes made to the
    structure while the iterator is consumed do not affect which elements
    it produces.  No de-duplication is performed; an element reachable
    along several paths appears once per path.

    The central API feature used by the :func:`.visitors.iterate` and
    :func:`.visitors.iterate_depthfirst` functions is the
    :meth:`_expression.ClauseElement.get_children` method of
    :class:`_expression.ClauseElement`
    objects.  This method should return all the
    :class:`_expression.ClauseElement` objects
    which are associated with a particular :class:`_expression.ClauseElement`
    object.
    For example, a :class:`.Case` structure will refer to a series of
    :class:`_expression.ColumnElement`
    objects within its "whens" and "else\_" member
    variables.

    :param obj: :class:`_expression.ClauseElement` structure to be traversed

    :param opts: dictionary of iteration options, passed as keyword
     arguments to every ``get_children()`` call.  This dictionary is
     usually empty in modern usage; ``None`` (the default) is treated as
     an empty dictionary.

    :return: an iterator over ``obj`` and all of its descendants.  An
     iterator (never a list) is returned, including when ``obj`` has no
     children.

    """
    if opts is None:
        opts = {}

    # fasttrack for atomic elements like columns
    children = obj.get_children(**opts)
    if not children:
        # Return an iterator here too, as documented; previously a bare
        # list was returned on this path only.
        return iter([obj])

    traversal = deque([obj])
    # The root's children were fetched above; seed the FIFO queue with
    # them instead of calling get_children() on the root a second time.
    queue = deque(children)
    while queue:
        t = queue.popleft()
        traversal.append(t)
        queue.extend(t.get_children(**opts))
    return iter(traversal)

279 

280 

def iterate_depthfirst(obj, opts=None):
    """Traverse the given expression structure, returning an iterator.

    Traversal is depth-first, and elements are produced in post-order:
    each element's children (left to right) come before the element
    itself, so ``obj`` is produced last.  The full sequence is computed
    eagerly before the iterator is returned.

    :param obj: :class:`_expression.ClauseElement` structure to be traversed

    :param opts: dictionary of iteration options, passed as keyword
     arguments to every ``get_children()`` call.  This dictionary is
     usually empty in modern usage; ``None`` (the default) is treated as
     an empty dictionary.

    :return: an iterator over ``obj`` and all of its descendants.  An
     iterator (never a list) is returned, including when ``obj`` has no
     children.

    .. seealso::

        :func:`.visitors.iterate` - includes a general overview of iteration.

    """
    if opts is None:
        opts = {}

    # fasttrack for atomic elements like columns
    children = obj.get_children(**opts)
    if not children:
        # Return an iterator here too, as documented; previously a bare
        # list was returned on this path only.
        return iter([obj])

    traversal = deque([obj])
    # Seed the LIFO stack with the root's already-fetched children rather
    # than calling get_children() on the root a second time.
    stack = deque(children)
    while stack:
        t = stack.pop()
        # Prepending each popped element yields children-before-parents.
        traversal.appendleft(t)
        stack.extend(t.get_children(**opts))
    return iter(traversal)

309 

310 

def traverse_using(iterator, obj, visitors):
    """Visit each element produced by ``iterator``, then return ``obj``.

    :func:`.visitors.traverse_using` is normally invoked internally by
    :func:`.visitors.traverse` or :func:`.visitors.traverse_depthfirst`.

    :param iterator: an iterable or sequence yielding
     :class:`_expression.ClauseElement` structures, normally the result of
     :func:`.visitors.iterate` or :func:`.visitors.iterate_depthfirst`.

    :param obj: the :class:`_expression.ClauseElement` that was the target
     of that :func:`.iterate` or :func:`.iterate_depthfirst` call; it is
     returned unchanged.

    :param visitors: dictionary mapping ``__visit_name__`` strings to
     visit functions.  Each element whose ``__visit_name__`` maps to a
     truthy value is passed to that function; all others are skipped.
     See :func:`.traverse` for details on this dictionary.

    .. seealso::

        :func:`.traverse`

        :func:`.traverse_depthfirst`

    """
    lookup = visitors.get
    for element in iterator:
        handler = lookup(element.__visit_name__, None)
        if not handler:
            continue
        handler(element)
    return obj

344 

345 

def traverse(obj, opts, visitors):
    """Visit every element of the given expression structure, in
    breadth-first order.

    e.g.::

        from sqlalchemy.sql import visitors

        stmt = select([some_table]).where(some_table.c.foo == 'bar')

        def visit_bindparam(bind_param):
            print("found bound value: %s" % bind_param.value)

        visitors.traverse(stmt, {}, {"bindparam": visit_bindparam})

    Elements are produced by :func:`.visitors.iterate`, which walks the
    structure breadth-first using a FIFO queue, and are then dispatched
    by :func:`.visitors.traverse_using`.

    :param obj: :class:`_expression.ClauseElement` structure to be traversed

    :param opts: dictionary of iteration options.  This dictionary is
     usually empty in modern usage.

    :param visitors: dictionary of visit functions.  Its keys are strings,
     each corresponding to the ``__visit_name__`` of a particular kind of
     SQL expression object; its values are callables, each being the
     visitor function for that kind of object.

    :return: ``obj``, unchanged.

    """
    elements = iterate(obj, opts)
    return traverse_using(elements, obj, visitors)

377 

378 

def traverse_depthfirst(obj, opts, visitors):
    """Visit every element of the given expression structure, in
    depth-first order.

    Elements are produced by :func:`.visitors.iterate_depthfirst`, which
    walks the structure with a stack, and are then dispatched by
    :func:`.visitors.traverse_using`.

    Arguments and return value are the same as for
    :func:`.visitors.traverse`.

    """
    elements = iterate_depthfirst(obj, opts)
    return traverse_using(elements, obj, visitors)
390 return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) 

391 

392 

def cloned_traverse(obj, opts, visitors):
    """Copy the given expression structure, letting visitors modify the
    copies as the traversal proceeds.

    Usage mirrors :func:`.visitors.traverse`, except that each visitor
    function receives a freshly cloned element and may alter that
    element's internals.

    Besides :meth:`_expression.ClauseElement.get_children`, which drives
    plain iteration, the :func:`.visitors.cloned_traverse` and
    :func:`.visitors.replacement_traverse` functions depend on the
    :meth:`_expression.ClauseElement._copy_internals` method.  A
    :class:`_expression.ClauseElement` structure supports cloning and
    replacement traversals only if it can pass a cloning function down
    into its internal members so that they are copied too.

    Each distinct element (by identity) is cloned once; later encounters
    reuse that copy.  Its visitor function, if any, runs after the
    element's internals have been copied.  Elements listed in
    ``opts["stop_on"]`` are returned as-is, without copying or visiting.
    If ``obj`` is ``None``, ``None`` is returned.

    .. seealso::

        :func:`.visitors.traverse`

        :func:`.visitors.replacement_traverse`

    """

    cloned = {}
    stop_on = set(opts.get("stop_on", []))

    def clone(elem, **kw):
        if elem in stop_on:
            return elem

        key = id(elem)
        if key not in cloned:
            newelem = elem._clone()
            # Register before descending so re-encounters share this copy.
            cloned[key] = newelem
            newelem._copy_internals(clone=clone, **kw)
            visit = visitors.get(newelem.__visit_name__, None)
            if visit:
                visit(newelem)
        return cloned[key]

    if obj is not None:
        obj = clone(obj)

    # clone() refers to itself through its closure cell; dropping the
    # name here breaks that reference cycle.
    clone = None

    return obj

441 

442 

def replacement_traverse(obj, opts, replace):
    """Clone the given expression structure, allowing element
    replacement by a given replacement function.

    This function is very similar to the :func:`.visitors.cloned_traverse`
    function, except instead of being passed a dictionary of visitors, all
    elements are unconditionally passed into the given replace function.
    The replace function then has the option to return an entirely new object
    which will replace the one given.  If it returns ``None``, then the object
    is kept in place.

    The difference in usage between :func:`.visitors.cloned_traverse` and
    :func:`.visitors.replacement_traverse` is that in the former case, an
    already-cloned object is passed to the visitor function, and the visitor
    function can then manipulate the internal state of the object.
    In the case of the latter, the visitor function should only return an
    entirely different object, or do nothing.

    The use case for :func:`.visitors.replacement_traverse` is that of
    replacing a FROM clause inside of a SQL structure with a different one,
    as is a common use case within the ORM.

    :param obj: structure to traverse, or ``None`` (returned as-is).

    :param opts: dictionary of traversal options.  An optional
     ``"stop_on"`` entry is an iterable of elements returned as-is (by
     identity) instead of being replaced or copied.  The whole dictionary
     is also passed as keyword arguments to the initial clone of ``obj``,
     and from there on to ``_copy_internals()``.

    :param replace: callable receiving each element before it is copied.
     A non-``None`` result is used in place of the element and is not
     traversed, including when encountered again.

    Elements whose ``_annotations`` contain ``"no_replacement_traverse"``
    are returned unchanged.  Each distinct element (by identity) is copied
    at most once; later encounters reuse that copy.

    """

    cloned = {}
    stop_on = {id(x) for x in opts.get("stop_on", [])}

    def clone(elem, **kw):
        if (
            id(elem) in stop_on
            or "no_replacement_traverse" in elem._annotations
        ):
            return elem
        else:
            newelem = replace(elem)
            if newelem is not None:
                # Never descend into a replacement, even if met again.
                stop_on.add(id(newelem))
                return newelem
            else:
                # Track "already copied" by identity, as cloned_traverse
                # and stop_on above do.  Keying on the element itself goes
                # through __hash__/__eq__, so two distinct elements that
                # hash and compare equal would wrongly share one copy.
                key = id(elem)
                if key not in cloned:
                    cloned[key] = newelem = elem._clone()
                    newelem._copy_internals(clone=clone, **kw)
                return cloned[key]

    if obj is not None:
        obj = clone(obj, **opts)
    clone = None  # remove gc cycles
    return obj
490 return obj