0001import atexit
0002import inspect
0003import sys
0004import os
0005import threading
0006import types
0007try:
0008 from urlparse import urlparse, parse_qsl
0009 from urllib import unquote, quote, urlencode
0010except ImportError:
0011 from urllib.parse import urlparse, parse_qsl, unquote, quote, urlencode
0012
0013import warnings
0014import weakref
0015
0016from .cache import CacheSet
0017from . import classregistry
0018from . import col
0019from .converters import sqlrepr
0020from . import sqlbuilder
0021from .util.threadinglocal import local as threading_local
0022from .compat import PY2, string_type, unicode_type
0023
0024warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")
0025
0026_connections = {}
0027
0028
0029def _closeConnection(ref):
0030 conn = ref()
0031 if conn is not None:
0032 conn.close()
0033
0034
class ConsoleWriter:
    """Debug writer that prints each message as a line on a ``sys`` stream.

    ``loglevel`` names the ``sys`` attribute to write to (``"stdout"`` when
    falsy).  ``dbEncoding`` comes from the connection, defaulting to
    ``"ascii"``; it is only used on Python 2 to encode unicode messages.
    """

    def __init__(self, connection, loglevel):
        self.loglevel = "stdout" if not loglevel else loglevel
        encoding = getattr(connection, "dbEncoding", None)
        self.dbEncoding = encoding if encoding else "ascii"

    def write(self, text):
        """Write *text* followed by a newline to the configured stream."""
        stream = getattr(sys, self.loglevel)
        needsEncoding = PY2 and isinstance(text, unicode_type)
        if needsEncoding:
            try:
                text = text.encode(self.dbEncoding)
            except UnicodeEncodeError:
                # Fall back to the escaped repr without its u'...' wrapper.
                text = repr(text)[2:-1]
        stream.write(text + '\n')
0049
0050
class LogWriter:
    """Debug writer that forwards each message to a logger method.

    ``loglevel`` is the name of the logger method to call (e.g. ``"info"``).
    ``connection`` is unused here; it keeps the constructor shape aligned
    with ``ConsoleWriter``.
    """

    def __init__(self, connection, logger, loglevel):
        self.logger, self.loglevel = logger, loglevel
        self.logmethod = getattr(self.logger, self.loglevel)

    def write(self, text):
        """Emit *text* through the selected logger method."""
        emit = self.logmethod
        emit(text)
0059
0060
def makeDebugWriter(connection, loggerName, loglevel):
    """Return the debug writer for *connection*.

    With a truthy *loggerName*, messages go to ``logging.getLogger(loggerName)``
    via a ``LogWriter``.  Otherwise they go to a ``sys`` stream via a
    ``ConsoleWriter``.
    """
    if loggerName:
        import logging
        return LogWriter(connection, logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(connection, loglevel)
0067
0068
class Boolean(object):
    """A bool class that also understands some special string keywords

    Understands: yes/no, true/false, on/off, 1/0, case ignored.

    Calling ``Boolean(value)`` returns a plain ``bool``.  Values without a
    ``lower()`` method, or strings that are not keywords, fall back to
    ``bool(value)``.
    """
    _keywords = {
        '1': True,
        'yes': True,
        'true': True,
        'on': True,
        '0': False,
        'no': False,
        'false': False,
        'off': False,
    }

    def __new__(cls, value):
        try:
            lowered = value.lower()
        except AttributeError:
            return bool(value)
        # Keyword values are only True/False, so None means "not a keyword".
        flag = Boolean._keywords.get(lowered)
        return bool(value) if flag is None else flag
0083
0084
class DBConnection:
    """Base class for SQLObject database connections.

    Holds settings shared by every backend: debug flags, the object cache,
    the auto-commit policy, the naming style and the class registry.  It
    also implements URI generation and parsing.  ``uri()``/``oldUri()`` read
    ``dbName``, ``host``, ``port`` and ``db`` (and optionally ``user`` and
    ``password``), which concrete drivers are expected to provide.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        self.name = name
        self.debug = Boolean(debug)
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        self.debugWriter = makeDebugWriter(self, logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        self.style = style
        # id(raw connection) -> sequential number shown in debug output.
        self._connectionNumbers = {}
        self._connectionCount = 1
        self.autoCommit = Boolean(autoCommit)
        self.registry = registry or None
        # Be told about every SQLObject class added to our registry.
        classregistry.registry(self.registry).addCallback(self.soClassAdded)
        registerConnectionInstance(self)
        # Weak reference: the atexit hook must not keep this object alive.
        atexit.register(_closeConnection, weakref.ref(self))

    def oldUri(self):
        """Return this connection's URI in the legacy (unescaped) format.

        Raises AssertionError when a password is set without a user name.
        """
        auth = getattr(self, 'user', '') or ''
        if auth:
            # ``password`` is not guaranteed to exist on every driver.
            password = getattr(self, 'password', None)
            if password:
                auth = auth + ':' + password
            auth = auth + '@'
        else:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + db

    def uri(self):
        """Return this connection's URI with credentials and db percent-quoted.

        The result round-trips through ``_parseURI``.  Raises AssertionError
        when a password is set without a user name.
        """
        auth = getattr(self, 'user', '') or ''
        if auth:
            # safe='' so a '/' inside the credentials is escaped too; the
            # default safe='/' would end the netloc early when re-parsed.
            auth = quote(auth, safe='')
            password = getattr(self, 'password', None)
            if password:
                auth = auth + ':' + quote(password, safe='')
            auth = auth + '@'
        else:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + quote(db)

    @classmethod
    def connectionFromOldURI(cls, uri):
        """Build a connection of this class from a legacy-format URI."""
        return cls._connectionFromParams(*cls._parseOldURI(uri))

    @classmethod
    def connectionFromURI(cls, uri):
        """Build a connection of this class from a standard URI."""
        return cls._connectionFromParams(*cls._parseURI(uri))

    @staticmethod
    def _parseOldURI(uri):
        """Parse a legacy ``scheme://user:pass@host:port/path?a=b`` URI.

        Returns ``(user, password, host, port, path, args)``.  ``port`` is an
        int or None, and ``args`` maps query names to unquoted values.
        Raises ValueError for a non-integer or out-of-range port.
        """
        _scheme, rest = uri.split(':', 1)
        assert rest.startswith('/'), \
            "URIs must start with scheme:/ -- " \
            "you did not include a / (in %r)" % rest
        if rest.startswith('/') and not rest.startswith('//'):
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            host = None
            rest = rest[3:]
        else:
            rest = rest[2:]
            if rest.find('/') == -1:
                host = rest
                rest = ''
            else:
                host, rest = rest.split('/', 1)
        if host and host.find('@') != -1:
            user, host = host.rsplit('@', 1)
            if user.find(':') != -1:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and host.find(':') != -1:
            # Only the last ':' separates the port from the host.
            _host, port = host.rsplit(':', 1)
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, "
                                 "got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, "
                                 "got '%d' instead" % port)
            host = _host
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            # Legacy Windows drive syntax: C|/path -> C:/path
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if path.find('?') != -1:
            path, arglist = path.split('?', 1)
            for single in arglist.split('&'):
                if not single:
                    # Tolerate empty segments, e.g. a trailing '&'.
                    continue
                argname, argvalue = single.split('=', 1)
                args[argname] = unquote(argvalue)
        return user, password, host, port, path, args

    @staticmethod
    def _parseURI(uri):
        """Parse a standard URI into ``(user, password, host, port, path, args)``.

        User, password and path are percent-unquoted.  ``port`` is an int or
        None, and ``args`` maps query names to values.
        """
        parsed = urlparse(uri)
        if sys.version_info[0:2] == (2, 6):
            # Python 2.6's urlparse only splits netloc for known schemes,
            # so re-parse with the scheme swapped to 'http'.
            scheme = parsed.scheme
            parsed = urlparse(uri.replace(scheme, 'http', 1))
        host, path = parsed.hostname, parsed.path
        user, password, port = None, None, None
        if parsed.username:
            user = unquote(parsed.username)
        if parsed.password:
            password = unquote(parsed.password)
        if parsed.port:
            port = int(parsed.port)

        path = unquote(path)
        if (os.name == 'nt') and (len(path) > 2):
            # Legacy drive syntax /C|/path -> /C:/path
            if path[2] == '|':
                path = "%s:%s" % (path[0:2], path[3:])
            # Drop the leading slash in front of a drive letter.
            if (path[0] == '/') and (path[2] == ':'):
                path = path[1:]

        args = {}
        query = parsed.query
        if query:
            for name, value in parse_qsl(query):
                args[name] = value

        return user, password, host, port, path, args

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()
0269
0270
class ConnWrapper(object):

    """
    This represents a SQLObject class that is bound to a specific
    connection (instances have a connection instance variable, but
    classes are global, so this is binds the connection variable
    lazily when a class method is accessed)
    """

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        """Instantiate the wrapped class with ``connection`` forced to ours."""
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        """Proxy class attributes, binding our connection into methods.

        Non-method attributes are returned unchanged.  Methods that accept a
        ``connection`` argument are wrapped so the connection is passed
        automatically.  The introspection result is cached on the underlying
        function as ``takes_connection``.  Raises AssertionError for methods
        taking ``*args``/``**kw``, where this cannot be decided.
        """
        meth = getattr(self._soClass, attr)
        if not isinstance(meth, types.MethodType):
            # Plain attributes and static functions are passed through.
            return meth
        try:
            takes_conn = meth.takes_connection
        except AttributeError:
            # inspect.getargspec was removed in Python 3.11; prefer
            # getfullargspec and fall back for Python 2.
            getspec = getattr(inspect, 'getfullargspec', None)
            if getspec is None:
                args, varargs, varkw = inspect.getargspec(meth)[:3]
                kwonlyargs = []
            else:
                spec = getspec(meth)
                args, varargs, varkw = spec.args, spec.varargs, spec.varkw
                kwonlyargs = spec.kwonlyargs or []
            assert not varkw and not varargs, (
                "I cannot tell whether I must wrap this method, "
                "because it takes *args or **kw: %r"
                % meth)
            takes_conn = 'connection' in args or 'connection' in kwonlyargs
            meth.__func__.takes_connection = takes_conn
        if not takes_conn:
            return meth
        return ConnMethodWrapper(meth, self._connection)
0309
0310
class ConnMethodWrapper(object):
    """Callable proxy that calls a method with a fixed ``connection`` keyword.

    Any attribute not defined on the proxy is read from the wrapped method.
    """

    def __init__(self, method, connection):
        self._method, self._connection = method, connection

    def __getattr__(self, attr):
        target = self._method
        return getattr(target, attr)

    def __call__(self, *args, **kw):
        # Any caller-supplied ``connection`` is overridden by ours.
        kw.update(connection=self._connection)
        return self._method(*args, **kw)

    def __repr__(self):
        template = '<Wrapped %r with connection %r>'
        return template % (self._method, self._connection)
0327
0328
class DBAPI(DBConnection):

    """
    Base class for DB-API 2.0 backed connections.

    Subclass must define a `makeConnection()` method, which
    returns a newly-created connection object.

    ``_queryInsertID`` must also be defined; ``queryInsertID`` calls it
    with a pooled connection.  Subclasses also provide ``module`` (the
    DB-API driver module, whose ``Binary`` and ``Error`` are used here),
    ``supportTransactions`` and ``_setAutoCommit``.

    Raw connections are reused through ``self._pool``, which is guarded by
    ``self._poolLock``.
    """
    dbName = None

    def __init__(self, **kw):
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        # Type of the driver's Binary wrapper; NOTE(review): consumers of
        # this attribute are outside this module -- confirm usage.
        self._binaryType = type(self.module.Binary(b''))

    def _runWithConnection(self, meth, *args):
        """Call ``meth(conn, *args)`` with a pooled connection and return its result.

        The connection is always handed back via ``releaseConnection``,
        even when ``meth`` raises.
        """
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val

    def getConnection(self):
        """Take a connection from the pool, creating one if the pool is empty.

        Newly created connections get a sequential number for debug output.
        """
        with self._poolLock:
            if not self._pool:
                conn = self.makeConnection()
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            else:
                conn = self._pool.pop()
            if self.debug:
                s = 'ACQUIRE'
                if self._pool is not None:
                    s += ' pool=[%s]' % ', '.join(
                        [str(self._connectionNumbers[id(v)])
                         for v in self._pool])
                self.printDebug(conn, s, 'Pool')
            return conn

    def releaseConnection(self, conn, explicit=False):
        """Return *conn* to the pool, applying the auto-commit policy.

        For implicit releases (``explicit=False``) on transactional backends,
        the pending work is committed or rolled back according to
        ``self.autoCommit``.  With ``autoCommit == 'exception'`` the work is
        rolled back and an Exception is raised.  When pooling is disabled
        (``self._pool is None``) the connection is closed instead.
        """
        if self.debug:
            if explicit:
                s = 'RELEASE (explicit)'
            else:
                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                s += ' no pooling'
            else:
                s += ' pool=[%s]' % ', '.join(
                    [str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        if self.supportTransactions and not explicit:
            if self.autoCommit == 'exception':
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                raise Exception('Object used outside of a transaction; '
                                'implicit COMMIT or ROLLBACK not allowed')
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        if self._pool is not None:
            # Same lock as getConnection/close so the membership check and
            # the insert are atomic with respect to other threads.
            with self._poolLock:
                # Guard against the same connection being released twice.
                if conn not in self._pool:
                    self._pool.insert(0, conn)
        else:
            conn.close()

    def printDebug(self, conn, s, name, type='query'):
        """Write one debug line for *conn* through ``self.debugWriter``.

        ``type='query'`` prints *s* as-is; any other type prints ``repr(s)``.
        Lines named 'Pool' are printed only when ``self.debug == 'Pool'``.
        """
        if name == 'Pool' and self.debug != 'Pool':
            return
        if type == 'query':
            sep = ': '
        else:
            sep = '->'
            s = repr(s)
        n = self._connectionNumbers[id(conn)]
        spaces = ' ' * (8 - len(name))
        if self.debugThreading:
            # current_thread().name replaces the deprecated
            # currentThread().getName() spelling.
            threadName = threading.current_thread().name
            threadName = (':' + threadName + ' ' * (8 - len(threadName)))
        else:
            threadName = ''
        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % {
            'n': n, 'threadName': threadName, 'name': name,
            'spaces': spaces, 'sep': sep, 's': s}
        self.debugWriter.write(msg)

    def _executeRetry(self, conn, cursor, query):
        """Execute *query* on *cursor*; a hook subclasses may override to retry."""
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)

    def _query(self, conn, s):
        """Execute *s* on *conn*, discarding any result."""
        if self.debug:
            self.printDebug(conn, s, 'Query')
        self._executeRetry(conn, conn.cursor(), s)

    def query(self, s):
        """Execute *s* on a pooled connection."""
        return self._runWithConnection(self._query, s)

    def _queryAll(self, conn, s):
        """Execute *s* on *conn* and return ``cursor.fetchall()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return value

    def queryAll(self, s):
        """Execute *s* on a pooled connection and return all rows."""
        return self._runWithConnection(self._queryAll, s)

    def _queryAllDescription(self, conn, s):
        """
        Like queryAll, but returns (description, rows), where the
        description is cursor.description (which gives row types)
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAllDesc', 'result')
        return c.description, value

    def queryAllDescription(self, s):
        """Pooled-connection variant of ``_queryAllDescription``."""
        return self._runWithConnection(self._queryAllDescription, s)

    def _queryOne(self, conn, s):
        """Execute *s* on *conn* and return ``cursor.fetchone()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchone()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryOne', 'result')
        return value

    def queryOne(self, s):
        """Execute *s* on a pooled connection and return the first row."""
        return self._runWithConnection(self._queryOne, s)

    def _insertSQL(self, table, names, values):
        """Build an INSERT statement with values rendered by ``sqlrepr``."""
        return ("INSERT INTO %s (%s) VALUES (%s)" %
                (table, ', '.join(names),
                 ', '.join([self.sqlrepr(v) for v in values])))

    def transaction(self):
        """Start and return a new ``Transaction`` on this connection."""
        return Transaction(self)

    def queryInsertID(self, soInstance, id, names, values):
        """Insert a row via the subclass's ``_queryInsertID`` on a pooled connection."""
        return self._runWithConnection(self._queryInsertID, soInstance, id,
                                       names, values)

    def iterSelect(self, select):
        """Return an iterator over *select* that releases its connection when exhausted."""
        return select.IterationClass(self, self.getConnection(),
                                     select, keepConnection=False)

    def accumulateSelect(self, select, *expressions):
        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
            to the select object.

        Returns a single value for one expression, otherwise the whole row.
        """
        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
        q = self.sqlrepr(q)
        val = self.queryOne(q)
        if len(expressions) == 1:
            val = val[0]
        return val

    def queryForSelect(self, select):
        """Render *select* to SQL for this backend."""
        return self.sqlrepr(select.queryForSelect())

    def _SO_createJoinTable(self, join):
        self.query(self._SO_createJoinTableSQL(join))

    def _SO_createJoinTableSQL(self, join):
        """CREATE TABLE statement for a many-to-many intermediate table."""
        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
                (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join)))

    def _SO_dropJoinTable(self, join):
        self.query("DROP TABLE %s" % join.intermediateTable)

    def _SO_createIndex(self, soClass, index):
        self.query(self.createIndexSQL(soClass, index))

    def createIndexSQL(self, soClass, index):
        assert 0, 'Implement in subclasses'

    def createTable(self, soClass):
        """Create *soClass*'s table; return the constraint SQL still to run."""
        createSql, constraints = self.createTableSQL(soClass)
        self.query(createSql)

        return constraints

    def createReferenceConstraints(self, soClass):
        """Non-empty foreign-key constraint statements for *soClass*."""
        refConstraints = [self.createReferenceConstraint(soClass, column)
                          for column in soClass.sqlmeta.columnList
                          if isinstance(column, col.SOForeignKey)]
        refConstraintDefs = [constraint for constraint in refConstraints
                             if constraint]
        return refConstraintDefs

    def createSQL(self, soClass):
        """Normalise ``sqlmeta.createSQL`` into a list of statements.

        Accepts a str, list, tuple or dict (keyed by ``dbName``).  Returns an
        empty list when nothing is configured.
        """
        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
        if tableCreateSQLs:
            assert isinstance(tableCreateSQLs, (str, list, dict, tuple)), (
                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
                (soClass.__name__))
            if isinstance(tableCreateSQLs, dict):
                tableCreateSQLs = tableCreateSQLs.get(
                    soClass._connection.dbName, [])
            if isinstance(tableCreateSQLs, str):
                tableCreateSQLs = [tableCreateSQLs]
            if isinstance(tableCreateSQLs, tuple):
                tableCreateSQLs = list(tableCreateSQLs)
            assert isinstance(tableCreateSQLs, list), (
                'Unable to create a list from %s.sqlmeta.createSQL' %
                (soClass.__name__))
        return tableCreateSQLs or []

    def createTableSQL(self, soClass):
        """Return ``(createSql, constraints + extraSQL)`` for *soClass*."""
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        createSql = ('CREATE TABLE %s (\n%s\n)' %
                     (soClass.sqlmeta.table, self.createColumns(soClass)))
        return createSql, constraints + extraSQL

    def createColumns(self, soClass):
        """Column definitions (id first) joined for a CREATE TABLE body."""
        # 'column' avoids shadowing the module-level ``col`` import.
        columnDefs = [self.createIDColumn(soClass)] + [
            self.createColumn(soClass, column)
            for column in soClass.sqlmeta.columnList]
        return ",\n".join([" %s" % c for c in columnDefs])

    def createReferenceConstraint(self, soClass, col):
        assert 0, "Implement in subclasses"

    def createColumn(self, soClass, col):
        assert 0, "Implement in subclasses"

    def dropTable(self, tableName, cascade=False):
        self.query("DROP TABLE %s" % tableName)

    def clearTable(self, tableName):
        # NOTE(review): a bare DELETE without WHERE may be rejected by some
        # server configurations; left unchanged here.
        self.query("DELETE FROM %s" % tableName)

    def createBinary(self, value):
        """
        Create a binary object wrapper for the given database.
        """
        return self.module.Binary(value)

    def _SO_update(self, so, values):
        """UPDATE the row for *so* with ``(dbName, value)`` pairs."""
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                               for dbName, value in values]),
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectOne(self, so, columnNames):
        return self._SO_selectOneAlt(so, columnNames, so.q.id == so.id)

    def _SO_selectOneAlt(self, so, columnNames, condition):
        """Select *columnNames* from *so*'s table matching *condition* (one row)."""
        if columnNames:
            # String names become raw SQL constants; expressions pass through.
            columns = [isinstance(x, string_type) and
                       sqlbuilder.SQLConstant(x) or
                       x for x in columnNames]
        else:
            columns = None
        return self.queryOne(self.sqlrepr(sqlbuilder.Select(
            columns, staticTables=[so.sqlmeta.table], clause=condition)))

    def _SO_delete(self, so):
        self.query("DELETE FROM %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectJoin(self, soClass, column, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (soClass.sqlmeta.idName,
                              soClass.sqlmeta.table,
                              column,
                              self.sqlrepr(value)))

    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (getColumn,
                              table,
                              joinColumn,
                              self.sqlrepr(value)))

    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
                   (table,
                    firstColumn,
                    self.sqlrepr(firstValue),
                    secondColumn,
                    self.sqlrepr(secondValue)))

    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
                   (table,
                    firstColumn,
                    secondColumn,
                    self.sqlrepr(firstValue),
                    self.sqlrepr(secondValue)))

    def _SO_columnClause(self, soClass, kw):
        """Build a WHERE clause from column keyword arguments.

        Consumes (pops) recognised keys from *kw*.  Returns None when no
        conditions were given.  Raises TypeError for unknown keywords.
        None values compare with ``IS``; everything else uses ``=``.
        """
        from . import main
        ops = {None: "IS"}
        data = []
        if 'id' in kw:
            data.append((soClass.sqlmeta.idName, kw.pop('id')))
        for soColumn in soClass.sqlmeta.columnList:
            key = soColumn.name
            if key in kw:
                val = kw.pop(key)
                if soColumn.from_python:
                    val = soColumn.from_python(
                        val,
                        sqlbuilder.SQLObjectState(soClass, connection=self))
                data.append((soColumn.dbName, val))
            elif soColumn.foreignName in kw:
                obj = kw.pop(soColumn.foreignName)
                if isinstance(obj, main.SQLObject):
                    data.append((soColumn.dbName, obj.id))
                else:
                    data.append((soColumn.dbName, obj))
        if kw:
            # list(kw) gives a readable key list on Python 3 as well
            # (kw.keys() would render as dict_keys([...])).
            raise TypeError("got an unexpected keyword argument(s): "
                            "%r" % list(kw))

        if not data:
            return None
        return ' AND '.join(
            ['%s %s %s' %
             (dbName, ops.get(value, "="), self.sqlrepr(value))
             for dbName, value
             in data])

    def sqlrepr(self, v):
        """Render *v* as SQL for this backend's ``dbName``."""
        return sqlrepr(v, self.dbName)

    def __del__(self):
        self.close()

    def close(self):
        """Close every pooled connection, ignoring driver errors on close."""
        if not hasattr(self, '_pool'):
            # __init__ did not get far enough to create the pool.
            return
        if not self._pool:
            return
        with self._poolLock:
            # Re-check under the lock: another thread may have emptied it.
            if not self._pool:
                return
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    # Best effort: one failing close must not stop the rest.
                    pass
            del conn
            del conns

    def createEmptyDatabase(self):
        """
        Create an empty database.
        """
        raise NotImplementedError
0736
0737
class Iteration(object):
    """Iterator that yields SQLObject instances for a SELECT's rows.

    The query runs immediately on construction.  Once the cursor is
    exhausted (or the iterator is garbage-collected), the raw connection is
    released back to *dbconn* unless ``keepConnection`` is true.
    """

    def __init__(self, dbconn, rawconn, select, keepConnection=False):
        self.dbconn, self.rawconn, self.select = dbconn, rawconn, select
        self.keepConnection = keepConnection
        cursor = rawconn.cursor()
        self.cursor = cursor
        sql = dbconn.queryForSelect(select)
        self.query = sql
        if dbconn.debug:
            dbconn.printDebug(rawconn, sql, 'Select')
        dbconn._executeRetry(rawconn, cursor, sql)

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return the next instance, or None for a row whose id column is NULL."""
        row = self.cursor.fetchone()
        if row is None:
            self._cleanup()
            raise StopIteration
        rowId = row[0]
        if rowId is None:
            return None
        extra = {}
        if not self.select.ops.get('lazyColumns', 0):
            # Eager mode: hand the remaining columns to get() directly.
            extra['selectResults'] = row[1:]
        return self.select.sourceClass.get(rowId, connection=self.dbconn,
                                           **extra)

    def _cleanup(self):
        """Release resources once; later calls are no-ops."""
        if getattr(self, 'query', None) is None:
            # Already cleaned up, or __init__ failed before the query was set.
            return
        self.query = None
        if not self.keepConnection:
            self.dbconn.releaseConnection(self.rawconn)
        for attrName in ('dbconn', 'rawconn', 'select', 'cursor'):
            setattr(self, attrName, None)

    def __del__(self):
        self._cleanup()
0785
0786
class Transaction(object):
    """A transaction bound to one raw connection taken from a ``DBAPI``.

    Query methods run on that connection.  Any other attribute is looked up
    on the parent connection and rebound to this transaction.  ``commit``
    and ``rollback`` expire cached instances that may now be stale.
    """

    def __init__(self, dbConnection):
        # Stay "obsolete" until fully built so __del__ won't roll back a
        # half-initialised transaction.
        self._obsolete = True
        self._dbConnection = dbConnection
        rawConn = dbConnection.getConnection()
        self._connection = rawConn
        dbConnection._setAutoCommit(rawConn, 0)
        self.cache = CacheSet(cache=dbConnection.doCache)
        self._deletedCache = {}
        self._obsolete = False

    def assertActive(self):
        """Raise AssertionError if the transaction was already closed."""
        assert not self._obsolete, (
            "This transaction has already gone through ROLLBACK; "
            "begin another transaction")

    def query(self, s):
        self.assertActive()
        db = self._dbConnection
        return db._query(self._connection, s)

    def queryAll(self, s):
        self.assertActive()
        db = self._dbConnection
        return db._queryAll(self._connection, s)

    def queryOne(self, s):
        self.assertActive()
        db = self._dbConnection
        return db._queryOne(self._connection, s)

    def queryInsertID(self, soInstance, id, names, values):
        self.assertActive()
        db = self._dbConnection
        return db._queryInsertID(self._connection, soInstance, id,
                                 names, values)

    def iterSelect(self, select):
        """Run *select* and return an iterator over fully fetched results.

        Results are materialised into a list up front; the iteration keeps
        (does not release) the transaction's connection.
        """
        self.assertActive()
        fetched = list(select.IterationClass(self, self._connection,
                                             select, keepConnection=True))
        return iter(fetched)

    def _SO_delete(self, inst):
        """Record *inst* as deleted, then run the parent's delete on this transaction."""
        self._deletedCache.setdefault(inst.__class__.__name__, []).append(inst.id)
        func = self._dbConnection._SO_delete.__func__
        if PY2:
            bound = types.MethodType(func, self, self.__class__)
        else:
            bound = types.MethodType(func, self)
        return bound(inst)

    def commit(self, close=False):
        """COMMIT, then expire the parent cache's copies of touched objects.

        With ``close=True`` the connection is also released.  Does nothing
        when the transaction is already obsolete.
        """
        if self._obsolete:
            return
        db = self._dbConnection
        if db.debug:
            db.printDebug(self._connection, '', 'COMMIT')
        self._connection.commit()
        pending = [(clsName, sub.allIDs()) for clsName, sub
                   in self.cache.allSubCachesByClassNames().items()]
        pending += list(self._deletedCache.items())
        for clsName, ids in pending:
            for ident in list(ids):
                inst = self._dbConnection.cache.tryGetByName(ident, clsName)
                if inst is not None:
                    inst.expire()
        if close:
            self._makeObsolete()

    def rollback(self):
        """ROLLBACK, expire this transaction's cached objects and release.

        Does nothing when the transaction is already obsolete.
        """
        if self._obsolete:
            return
        db = self._dbConnection
        if db.debug:
            db.printDebug(self._connection, '', 'ROLLBACK')
        # Snapshot ids before rolling back.
        snapshot = [(sub, sub.allIDs()) for sub in self.cache.allSubCaches()]
        self._connection.rollback()

        for subCache, ids in snapshot:
            for ident in list(ids):
                inst = subCache.tryGet(ident)
                if inst is not None:
                    inst.expire()
        self._makeObsolete()

    def __getattr__(self, attr):
        """
        If nothing else works, let the parent connection handle it.
        Except with this transaction as 'self'. Poor man's
        acquisition? Bad programming? Okay, maybe.
        """
        self.assertActive()
        target = getattr(self._dbConnection, attr)
        try:
            func = target.__func__
        except AttributeError:
            if isinstance(target, ConnWrapper):
                return ConnWrapper(target._soClass, self)
            return target
        if PY2:
            return types.MethodType(func, self, self.__class__)
        return types.MethodType(func, self)

    def _makeObsolete(self):
        """Mark closed, restore auto-commit if configured, release the connection."""
        self._obsolete = True
        db = self._dbConnection
        if db.autoCommit:
            db._setAutoCommit(self._connection, 1)
        db.releaseConnection(self._connection, explicit=True)
        self._connection = None
        self._deletedCache = {}

    def begin(self):
        """Reopen an obsolete transaction on a fresh connection."""
        assert self._obsolete, (
            "You cannot begin a new transaction session "
            "without rolling back this one")
        self._obsolete = False
        fresh = self._dbConnection.getConnection()
        self._connection = fresh
        self._dbConnection._setAutoCommit(fresh, 0)

    def __del__(self):
        if not self._obsolete:
            self.rollback()

    def close(self):
        raise TypeError('You cannot just close transaction - '
                        'you should either call rollback(), commit() '
                        'or commit(close=True) '
                        'to close the underlying connection.')
0929
0930
class ConnectionHub(object):

    """
    This object serves as a hub for connections, so that you can pass
    in a ConnectionHub to a SQLObject subclass as though it was a
    connection, but actually bind a real database connection later.
    You can also bind connections on a per-thread basis.

    You must hang onto the original ConnectionHub instance, as you
    cannot retrieve it again from the class or instance.

    To use the hub, do something like::

        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub

        hub.threadConnection = connectionFromURI('...')

    A per-thread connection (``threadConnection``) takes precedence over
    the process-wide ``processConnection``.
    """

    def __init__(self):
        self.threadingLocal = threading_local()

    def __get__(self, obj, type=None):
        # A connection assigned on an instance (via __set__) overrides the
        # hub for that instance.
        if (obj is not None) and '_connection' in obj.__dict__:
            return obj.__dict__['_connection']
        return self.getConnection()

    def __set__(self, obj, value):
        obj.__dict__['_connection'] = value

    def getConnection(self):
        """Return the thread connection, else the process connection.

        Raises AttributeError when neither has been set.
        """
        try:
            return self.threadingLocal.connection
        except AttributeError:
            try:
                return self.processConnection
            except AttributeError:
                raise AttributeError(
                    "No connection has been defined for this thread "
                    "or process")

    def doInTransaction(self, func, *args, **kw):
        """
        This routine can be used to run a function in a transaction,
        rolling the transaction back if any exception is raised from
        that function, and committing otherwise.

        Use like::

            sqlhub.doInTransaction(process_request, os.environ)

        This will run ``process_request(os.environ)``.  The return
        value will be preserved.

        While *func* runs, the transaction temporarily replaces whichever
        connection (thread or process) was active; that connection is
        restored afterwards.  Raises AttributeError when no connection is
        defined.
        """
        try:
            old_conn = self.threadingLocal.connection
            old_conn_is_threading = True
        except AttributeError:
            try:
                old_conn = self.processConnection
            except AttributeError:
                # Same explicit message as getConnection().
                raise AttributeError(
                    "No connection has been defined for this thread "
                    "or process")
            old_conn_is_threading = False
        conn = old_conn.transaction()
        if old_conn_is_threading:
            self.threadConnection = conn
        else:
            self.processConnection = conn
        try:
            try:
                value = func(*args, **kw)
            except BaseException:
                conn.rollback()
                raise
            else:
                conn.commit(close=True)
                return value
        finally:
            if old_conn_is_threading:
                self.threadConnection = old_conn
            else:
                self.processConnection = old_conn

    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value

    def _get_threadConnection(self):
        return self.threadingLocal.connection

    def _del_threadConnection(self):
        del self.threadingLocal.connection

    threadConnection = property(_get_threadConnection,
                                _set_threadConnection,
                                _del_threadConnection)
1030
1031
class ConnectionURIOpener(object):
    """Resolves connection URIs and registered names to connection objects.

    ``schemeBuilders`` maps a URI scheme to a zero-argument callable that
    returns the connection class.  ``instanceNames`` maps a registered
    connection name to its instance.  ``cachedURIs`` memoises every
    resolved URI string.
    """

    def __init__(self):
        self.schemeBuilders = {}
        self.instanceNames = {}
        self.cachedURIs = {}

    def registerConnection(self, schemes, builder):
        """Register *builder* for each scheme in *schemes*.

        Re-registering the same builder is allowed; a different builder for
        an existing scheme raises AssertionError.
        """
        for uriScheme in schemes:
            assert (uriScheme not in self.schemeBuilders or
                    self.schemeBuilders[uriScheme] is builder), (
                "A driver has already been registered "
                "for the URI scheme %s" % uriScheme)
            self.schemeBuilders[uriScheme] = builder

    def registerConnectionInstance(self, inst):
        """Register *inst* under ``inst.name`` (unnamed instances are ignored).

        Re-registering the same instance is allowed.  A different instance
        with the same name, or a name containing ':', raises AssertionError.
        """
        if inst.name:
            # Compare against ``inst`` itself; the old code referenced an
            # undefined name ``cls`` and raised NameError instead.
            assert (inst.name not in self.instanceNames or
                    self.instanceNames[inst.name] is inst), (
                "An instance has already been registered "
                "with the name %s" % inst.name)
            # ':' would make the name look like a URI scheme.
            assert inst.name.find(':') == -1, (
                "You cannot include ':' "
                "in your connection names (%r)" % inst.name)
            self.instanceNames[inst.name] = inst

    def connectionForURI(self, uri, oldUri=False, **args):
        """Return a (cached) connection for *uri* or a registered name.

        Extra keyword arguments are URL-encoded and appended as query
        parameters before lookup.  ``oldUri=True`` parses *uri* with the
        legacy parser.
        """
        if args:
            if '?' not in uri:
                uri += '?' + urlencode(args)
            else:
                uri += '&' + urlencode(args)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if uri.find(':') != -1:
            scheme = uri.split(':', 1)[0]
            connCls = self.dbConnectionForScheme(scheme)
            if oldUri:
                conn = connCls.connectionFromOldURI(uri)
            else:
                conn = connCls.connectionFromURI(uri)
        else:
            # Just a registered name, not a URI.
            assert uri in self.instanceNames, \
                "No SQLObject driver exists under the name %s" % uri
            conn = self.instanceNames[uri]

        self.cachedURIs[uri] = conn
        return conn

    def dbConnectionForScheme(self, scheme):
        """Return the connection class for *scheme* by calling its builder."""
        assert scheme in self.schemeBuilders, (
            "No SQLObject driver exists for %s (only %s)" % (
                scheme,
                ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()
1088
# Process-wide opener.  The module-level names below are its bound methods,
# so callers can use e.g. ``connectionForURI(...)`` directly.
TheURIOpener = ConnectionURIOpener()

(registerConnection,
 registerConnectionInstance,
 connectionForURI,
 dbConnectionForScheme) = (TheURIOpener.registerConnection,
                           TheURIOpener.registerConnectionInstance,
                           TheURIOpener.connectionForURI,
                           TheURIOpener.dbConnectionForScheme)
1095
1096
1097
1098from . import firebird
1099from . import maxdb
1100from . import mssql
1101from . import mysql
1102from . import postgres
1103from . import rdbhost
1104from . import sqlite
1105from . import sybase