Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import re 

2from . import err 

3 

4 

#: Pattern used by :meth:`Cursor.executemany` to recognise statements it can
#: batch into a single multi-row statement.
#: Only simple bulk ``INSERT`` / ``REPLACE ... VALUES (...)`` statements match;
#: this is the form to use when loading a large dataset.
#:
#: Groups: (1) everything up to and including ``VALUES``, (2) the parenthesised
#: placeholder tuple, (3) an optional ``ON DUPLICATE ...`` tail.
_INSERT_VALUES_FLAGS = re.IGNORECASE | re.DOTALL
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
    r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
    _INSERT_VALUES_FLAGS,
)

14 

15 

class Cursor:
    """
    This is the object you use to interact with the database.

    Do not create an instance of a Cursor yourself. Call
    connections.Connection.cursor().

    See `Cursor <https://www.python.org/dev/peps/pep-0249/#cursor-objects>`_ in
    the specification.
    """

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet - packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    def __init__(self, connection):
        """Bind the cursor to *connection* and reset all result state."""
        self.connection = connection
        self.description = None
        self.rownumber = 0
        self.rowcount = -1
        self.arraysize = 1
        # Initialised here so that reading them before the first query gives
        # None instead of raising AttributeError.
        self.lastrowid = None
        self._last_executed = None
        self._executed = None
        self._result = None
        self._rows = None

    def close(self):
        """
        Closing a cursor just exhausts all remaining data.

        Calling it on an already closed cursor is a no-op.
        """
        conn = self.connection
        if conn is None:
            return
        try:
            while self.nextset():
                pass
        finally:
            # Detach even if draining the remaining result sets failed.
            self.connection = None

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        del exc_info
        self.close()

    def _get_db(self):
        """Return the bound connection or raise ProgrammingError if closed."""
        if not self.connection:
            raise err.ProgrammingError("Cursor closed")
        return self.connection

    def _check_executed(self):
        """Raise ProgrammingError unless a statement has been executed."""
        if not self._executed:
            raise err.ProgrammingError("execute() first")

    def _conv_row(self, row):
        """Hook for subclasses to convert a raw row; identity here."""
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    def _nextset(self, unbuffered=False):
        """Get the next query set

        Returns True when a further result set was loaded, None when this
        cursor's result is not the connection's current one or there is no
        following result set.
        """
        conn = self._get_db()
        current_result = self._result
        if current_result is None or current_result is not conn._result:
            return None
        if not current_result.has_next:
            return None
        self._result = None
        self._clear_result()
        conn.next_result(unbuffered=unbuffered)
        self._do_get_result()
        return True

    def nextset(self):
        """Advance to the next result set (buffered)."""
        return self._nextset(False)

    def _ensure_bytes(self, x, encoding=None):
        """Encode ``str`` values (recursively inside tuples/lists) to bytes.

        :param x: value to convert; non-str scalars are returned unchanged.
        :param encoding: codec name. When omitted, the connection's
            ``encoding`` is used (``str.encode(None)`` would raise TypeError).
        """
        if isinstance(x, str):
            if encoding is None:
                encoding = self._get_db().encoding
            x = x.encode(encoding)
        elif isinstance(x, (tuple, list)):
            x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
        return x

    def _escape_args(self, args, conn):
        """Escape query parameters through *conn*.

        Sequences become a tuple of literals, dicts a dict of literals;
        anything else is passed to ``conn.escape`` as a single value.
        """
        if isinstance(args, (tuple, list)):
            return tuple(conn.literal(arg) for arg in args)
        elif isinstance(args, dict):
            return {key: conn.literal(val) for (key, val) in args.items()}
        else:
            # If it's not a dictionary let's try escaping it anyways.
            # Worst case it will throw a Value error
            return conn.escape(args)

    def mogrify(self, query, args=None):
        """
        Returns the exact string that is sent to the database by calling the
        execute() method.

        This method follows the extension to the DB API 2.0 followed by Psycopg.
        """
        conn = self._get_db()

        if args is not None:
            query = query % self._escape_args(args, conn)

        return query

    def execute(self, query, args=None):
        """Execute a query

        :param str query: Query to execute.

        :param args: parameters used with query. (optional)
        :type args: tuple, list or dict

        :return: Number of affected rows
        :rtype: int

        If args is a list or tuple, %s can be used as a placeholder in the query.
        If args is a dict, %(name)s can be used as a placeholder in the query.
        """
        # Drain any result sets left over from the previous statement.
        while self.nextset():
            pass

        query = self.mogrify(query, args)

        result = self._query(query)
        self._executed = query
        return result

    def executemany(self, query, args):
        # type: (str, list) -> int
        """Run several data against one query

        :param query: query to execute on server
        :param args: Sequence of sequences or mappings. It is used as parameter.
        :return: Number of rows affected, if any. ``None`` is returned when
            *args* is empty (nothing is executed).

        This method improves performance on multiple-row INSERT and
        REPLACE. Otherwise it is equivalent to looping over args with
        execute().
        """
        if not args:
            return

        m = RE_INSERT_VALUES.match(query)
        if m:
            # The prefix carries no parameters; formatting it with an empty
            # tuple collapses any literal "%%" it contains.
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ""
            assert q_values[0] == "(" and q_values[-1] == ")"
            return self._do_execute_many(
                q_prefix,
                q_values,
                q_postfix,
                args,
                self.max_stmt_length,
                self._get_db().encoding,
            )

        self.rowcount = sum(self.execute(query, arg) for arg in args)
        return self.rowcount

    def _do_execute_many(
        self, prefix, values, postfix, args, max_stmt_length, encoding
    ):
        """Send *args* as one or more multi-row statements.

        Each statement is ``prefix + row[,row...] + postfix``; a new statement
        is started whenever adding the next row would exceed
        *max_stmt_length*. Returns (and stores in ``rowcount``) the sum of the
        row counts reported by every statement executed.
        """
        conn = self._get_db()
        if isinstance(prefix, str):
            prefix = prefix.encode(encoding)
        if isinstance(postfix, str):
            postfix = postfix.encode(encoding)

        def render(arg):
            # Format one parameter set into the VALUES tuple, as bytes.
            row = values % self._escape_args(arg, conn)
            if isinstance(row, str):
                row = row.encode(encoding, "surrogateescape")
            return row

        args = iter(args)
        sql = bytearray(prefix)
        sql += render(next(args))
        rows = 0
        for arg in args:
            v = render(arg)
            # +1 accounts for the separating comma.
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                rows += self.execute(sql + postfix)
                sql = bytearray(prefix)
            else:
                sql += b","
            sql += v
        rows += self.execute(sql + postfix)
        self.rowcount = rows
        return rows

    def callproc(self, procname, args=()):
        """Execute stored procedure procname with args

        procname -- string, name of procedure to execute on server

        args -- Sequence of parameters to use with procedure

        Returns the original args.

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.
        """
        conn = self._get_db()
        # NOTE: procname is interpolated into the SQL text unescaped (only the
        # argument values go through conn.escape); do not pass untrusted names.
        if args:
            assignments = ",".join(
                f"@_{procname}_{index}={conn.escape(arg)}"
                for index, arg in enumerate(args)
            )
            self._query(f"SET {assignments}")
            self.nextset()

        params = ",".join(f"@_{procname}_{i}" for i in range(len(args)))
        q = f"CALL {procname}({params})"
        self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row, or None when no rows remain."""
        self._check_executed()
        if self._rows is None or self.rownumber >= len(self._rows):
            return None
        result = self._rows[self.rownumber]
        self.rownumber += 1
        return result

    def fetchmany(self, size=None):
        """Fetch several rows

        Up to *size* rows are returned (``arraysize`` when *size* is falsy).
        An empty tuple is returned when there is no result set.
        """
        self._check_executed()
        if self._rows is None:
            return ()
        end = self.rownumber + (size or self.arraysize)
        result = self._rows[self.rownumber : end]
        self.rownumber = min(end, len(self._rows))
        return result

    def fetchall(self):
        """Fetch all the rows

        Returns the rows not yet consumed; an empty tuple when there is no
        result set.
        """
        self._check_executed()
        if self._rows is None:
            return ()
        if self.rownumber:
            result = self._rows[self.rownumber :]
        else:
            result = self._rows
        self.rownumber = len(self._rows)
        return result

    def scroll(self, value, mode="relative"):
        """Move the row position.

        :param value: offset (``mode="relative"``) or target index
            (``mode="absolute"``).
        :raises ProgrammingError: for an unknown *mode*.
        :raises IndexError: when the target lies outside the result set,
            including when the last statement produced no rows at all.
        """
        self._check_executed()
        if mode == "relative":
            r = self.rownumber + value
        elif mode == "absolute":
            r = value
        else:
            raise err.ProgrammingError("unknown scroll mode %s" % mode)

        # _rows is None when there is no result set; treat every position as
        # out of range instead of failing on len(None).
        if self._rows is None or not (0 <= r < len(self._rows)):
            raise IndexError("out of range")
        self.rownumber = r

    def _query(self, q):
        """Send *q* through the connection and load its result.

        Returns the resulting ``rowcount``.
        """
        conn = self._get_db()
        self._last_executed = q
        self._clear_result()
        conn.query(q)
        self._do_get_result()
        return self.rowcount

    def _clear_result(self):
        """Reset all per-result attributes before a new result is loaded."""
        self.rownumber = 0
        self._result = None

        self.rowcount = 0
        self.description = None
        self.lastrowid = None
        self._rows = None

    def _do_get_result(self):
        """Copy the connection's current result into the cursor attributes."""
        conn = self._get_db()

        self._result = result = conn._result

        self.rowcount = result.affected_rows
        self.description = result.description
        self.lastrowid = result.insert_id
        self._rows = result.rows

    def __iter__(self):
        # Yields rows from fetchone() until it returns None.
        return iter(self.fetchone, None)

    # Exception classes re-exported as cursor attributes.
    Warning = err.Warning
    Error = err.Error
    InterfaceError = err.InterfaceError
    DatabaseError = err.DatabaseError
    DataError = err.DataError
    OperationalError = err.OperationalError
    IntegrityError = err.IntegrityError
    InternalError = err.InternalError
    ProgrammingError = err.ProgrammingError
    NotSupportedError = err.NotSupportedError

346 

347 

class DictCursorMixin:
    """Mixin that turns each fetched row into a mapping keyed by column name."""

    # You can override this to use OrderedDict or other dict-like types.
    dict_type = dict

    def _do_get_result(self):
        """Load the result, then derive column keys and convert buffered rows.

        A column whose name was already seen is keyed as
        ``table_name.name`` so that it does not overwrite the earlier one.
        """
        super()._do_get_result()
        column_keys = []
        if self.description:
            for field in self._result.fields:
                key = field.name
                if key in column_keys:
                    key = field.table_name + "." + key
                column_keys.append(key)
            self._fields = column_keys

        # Only rows that are already present are converted here.
        if column_keys and self._rows:
            self._rows = [self._conv_row(raw) for raw in self._rows]

    def _conv_row(self, row):
        """Return *row* as ``dict_type`` keyed by column, or None for None."""
        return None if row is None else self.dict_type(zip(self._fields, row))

370 

371 

class DictCursor(DictCursorMixin, Cursor):
    """Buffered cursor whose rows are returned as dictionaries."""

374 

375 

class SSCursor(Cursor):
    """
    Unbuffered Cursor, mainly useful for queries that return a lot of data,
    or for connections to remote servers over a slow network.

    Rows are fetched as they are requested rather than being copied into a
    buffer up front. The client therefore uses much less memory, and rows
    arrive sooner over a slow network or for a very large result set.

    There are limitations, though. The MySQL protocol doesn't support
    returning the total number of rows, so the only way to tell how many rows
    there are is to iterate over every row returned. Also, it currently isn't
    possible to scroll backwards, as only the current row is held in memory.
    """

    def _conv_row(self, row):
        return row

    def close(self):
        """Finish any pending unbuffered query, drain result sets, detach."""
        db = self.connection
        if db is None:
            return

        pending = self._result
        if pending is not None and pending is db._result:
            pending._finish_unbuffered_query()

        try:
            while self.nextset():
                pass
        finally:
            self.connection = None

    __del__ = close

    def _query(self, q):
        """Send *q* as an unbuffered query and return the ``rowcount``."""
        db = self._get_db()
        self._last_executed = q
        self._clear_result()
        db.query(q, unbuffered=True)
        self._do_get_result()
        return self.rowcount

    def nextset(self):
        """Advance to the next result set (unbuffered)."""
        return self._nextset(unbuffered=True)

    def read_next(self):
        """Read next row"""
        raw = self._result._read_rowdata_packet_unbuffered()
        return self._conv_row(raw)

    def fetchone(self):
        """Fetch next row"""
        self._check_executed()
        row = self.read_next()
        if row is not None:
            self.rownumber += 1
        return row

    def fetchall(self):
        """
        Fetch all, as per MySQLdb. Pretty useless for large queries, as
        it is buffered. See fetchall_unbuffered(), if you want an unbuffered
        generator version of this method.
        """
        return list(self.fetchall_unbuffered())

    def fetchall_unbuffered(self):
        """
        Fetch all as an iterator, which isn't to standard; returning
        everything in a list would use excessive memory for large result sets.
        """
        return iter(self.fetchone, None)

    def __iter__(self):
        return self.fetchall_unbuffered()

    def fetchmany(self, size=None):
        """Fetch many"""
        self._check_executed()
        limit = self.arraysize if size is None else size

        fetched = []
        for _ in range(limit):
            row = self.read_next()
            if row is None:
                break
            fetched.append(row)
            self.rownumber += 1
        return fetched

    def scroll(self, value, mode="relative"):
        """Skip forward by reading and discarding rows.

        Raises NotSupportedError for a backwards move and ProgrammingError
        for an unknown *mode*.
        """
        self._check_executed()

        if mode == "relative":
            skip = value
        elif mode == "absolute":
            skip = value - self.rownumber
        else:
            raise err.ProgrammingError("unknown scroll mode %s" % mode)

        if skip < 0:
            raise err.NotSupportedError(
                "Backwards scrolling not supported by this cursor"
            )

        for _ in range(skip):
            self.read_next()
        self.rownumber += skip

493 

494 

class SSDictCursor(DictCursorMixin, SSCursor):
    """Unbuffered cursor whose rows are returned as dictionaries."""