0001import atexit
0002from cgi import parse_qsl
0003import inspect
0004import new
0005import os
0006import sys
0007import threading
0008import types
0009import urllib
0010import warnings
0011import weakref
0012
0013from cache import CacheSet
0014import classregistry
0015import col
0016from converters import sqlrepr
0017import main
0018import sqlbuilder
0019from util.threadinglocal import local as threading_local
0020
# Suppress warnings whose message begins with this text (presumably
# raised by DB-API drivers when ``cursor.lastrowid`` is accessed).
warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")

# NOTE(review): not referenced anywhere in the visible part of this
# module -- confirm whether external code still relies on it.
_connections = {}
0024
0025def _closeConnection(ref):
0026 conn = ref()
0027 if conn is not None:
0028 conn.close()
0029
class ConsoleWriter:
    """
    Debug writer that sends each message, followed by a newline, to the
    ``sys`` attribute named by ``loglevel`` (``stdout`` when not given).
    """

    def __init__(self, connection, loglevel):
        if not loglevel:
            loglevel = "stdout"
        self.loglevel = loglevel
        # Encoding used for unicode messages; taken from the connection
        # when it provides one, ASCII otherwise.
        encoding = getattr(connection, "dbEncoding", None)
        if not encoding:
            encoding = "ascii"
        self.dbEncoding = encoding

    def write(self, text):
        """Write ``text`` plus a trailing newline to the chosen stream."""
        stream = getattr(sys, self.loglevel)
        if isinstance(text, unicode):
            try:
                text = text.encode(self.dbEncoding)
            except UnicodeEncodeError:
                # Fall back to the repr, minus its leading u' and
                # trailing quote.
                text = repr(text)[2:-1]
        stream.write(text + '\n')
0043
class LogWriter:
    """
    Debug writer that forwards each message to one method of a logger
    object (``logger.debug``, ``logger.info``, ...).

    ``connection`` is accepted for signature parity with
    ``ConsoleWriter`` and is not used.  ``loglevel`` is the name of the
    logger method to call; when it is empty or ``None`` the ``debug``
    method is used, instead of failing with a ``TypeError`` from
    ``getattr(logger, None)``.
    """

    def __init__(self, connection, logger, loglevel):
        self.logger = logger
        self.loglevel = loglevel or "debug"
        # Resolve the bound method once; write() just calls it.
        self.logmethod = getattr(logger, self.loglevel)

    def write(self, text):
        """Pass ``text`` to the selected logger method."""
        self.logmethod(text)
0051
def makeDebugWriter(connection, loggerName, loglevel):
    """
    Return the writer used for debug output.

    With a ``loggerName`` the messages go to that ``logging`` logger
    through a ``LogWriter``; without one they go to the console through
    a ``ConsoleWriter``.
    """
    if loggerName:
        import logging
        return LogWriter(connection, logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(connection, loglevel)
0058
class Boolean(object):
    """A bool class that also understands some special string keywords (yes/no, true/false, on/off, 1/0)"""

    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
                 '0': False, 'no': False, 'false': False, 'off': False}

    def __new__(cls, value):
        # Returns a plain ``bool``, never a ``Boolean`` instance.
        # Values without ``lower()`` and strings that are not keywords
        # both fall back to ordinary truthiness.
        try:
            result = Boolean._keywords[value.lower()]
        except (AttributeError, KeyError):
            result = bool(value)
        return result
0068
class DBConnection:

    """
    Base connection class: stores the debug, cache and registry
    settings and converts between connection objects and URIs.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        self.name = name
        self.style = style
        self.registry = registry or None
        self._connectionNumbers = {}
        self._connectionCount = 1
        # Flags go through Boolean so that string keywords such as
        # 'yes' or '0' are understood as well as real booleans.
        self.debug = Boolean(debug)
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        self.autoCommit = Boolean(autoCommit)
        self.debugWriter = makeDebugWriter(self, logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        classRegistry = classregistry.registry(self.registry)
        classRegistry.addCallback(self.soClassAdded)
        registerConnectionInstance(self)
        # A weak reference keeps the exit hook from holding the
        # connection alive.
        atexit.register(_closeConnection, weakref.ref(self))

    def oldUri(self):
        """Return the connection URI without any URL quoting."""
        credentials = getattr(self, 'user', '') or ''
        if not credentials:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        else:
            if self.password:
                credentials = credentials + ':' + self.password
            credentials = credentials + '@'
        result = '%s://%s' % (self.dbName, credentials)
        if self.host:
            result += self.host
            if self.port:
                result += ':%d' % self.port
        result += '/'
        database = self.db
        if database.startswith('/'):
            database = database[1:]
        return result + database

    def uri(self):
        """
        Return the connection URI with user, password and database
        name URL-quoted.
        """
        credentials = getattr(self, 'user', '') or ''
        if not credentials:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        else:
            credentials = urllib.quote(credentials)
            if self.password:
                credentials = credentials + ':' + urllib.quote(self.password)
            credentials = credentials + '@'
        result = '%s://%s' % (self.dbName, credentials)
        if self.host:
            result += self.host
            if self.port:
                result += ':%d' % self.port
        result += '/'
        database = self.db
        if database.startswith('/'):
            database = database[1:]
        return result + urllib.quote(database)

    @classmethod
    def connectionFromOldURI(cls, uri):
        """Build a connection from an old-style (unquoted) URI."""
        params = cls._parseOldURI(uri)
        return cls._connectionFromParams(*params)

    @classmethod
    def connectionFromURI(cls, uri):
        """Build a connection from a URL-quoted URI."""
        params = cls._parseURI(uri)
        return cls._connectionFromParams(*params)

    @staticmethod
    def _parseOldURI(uri):
        """
        Split an old-style URI by hand.

        Returns ``(user, password, host, port, path, args)``; raises
        ``ValueError`` for a port that is not an integer in 1-65535.
        """
        schema, remainder = uri.split(':', 1)
        assert remainder.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % remainder
        if remainder.startswith('/') and not remainder.startswith('//'):
            # scheme:/path -- no host part at all
            host = None
            remainder = remainder[1:]
        elif remainder.startswith('///'):
            # scheme:///path -- empty host part
            host = None
            remainder = remainder[3:]
        else:
            remainder = remainder[2:]
            if remainder.find('/') == -1:
                host, remainder = remainder, ''
            else:
                host, remainder = remainder.split('/', 1)

        user = password = None
        if host and host.find('@') != -1:
            user, host = host.rsplit('@', 1)
            if user.find(':') != -1:
                user, password = user.split(':', 1)

        port = None
        if host and host.find(':') != -1:
            hostName, port = host.split(':')
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
            host = hostName

        path = '/' + remainder
        if os.name == 'nt':
            # Drive letters written as "C|" become "C:"
            if (len(remainder) > 1) and (remainder[1] == '|'):
                path = "%s:%s" % (remainder[0], remainder[2:])

        args = {}
        if path.find('?') != -1:
            path, queryString = path.split('?', 1)
            for pair in queryString.split('&'):
                argName, argValue = pair.split('=', 1)
                args[argName] = urllib.unquote(argValue)
        return user, password, host, port, path, args

    @staticmethod
    def _parseURI(uri):
        """
        Split a URL-quoted URI using the ``urllib`` helpers.

        Returns ``(user, password, host, port, path, args)``.
        """
        protocol, request = urllib.splittype(uri)
        user = password = port = None
        host, path = urllib.splithost(request)

        if host:
            # Split on '@' by hand so that unquoting happens only
            # after user and password have been separated.
            if '@' in host:
                user, host = host.split('@', 1)
            if user:
                user, password = [x and urllib.unquote(x) or None for x in urllib.splitpasswd(user)]
            host, port = urllib.splitport(host)
            if port:
                port = int(port)
        elif host == '':
            host = None

        # The fragment is split off and then ignored.
        path, tag = urllib.splittag(path)
        path, query = urllib.splitquery(path)

        path = urllib.unquote(path)
        if (os.name == 'nt') and (len(path) > 2):
            # "/C|/dir" is accepted as a spelling of "/C:/dir"
            if path[2] == '|':
                path = "%s:%s" % (path[0:2], path[3:])
            # Drop the slash in front of a drive letter
            if (path[0] == '/') and (path[2] == ':'):
                path = path[1:]

        args = {}
        if query:
            for argName, argValue in parse_qsl(query):
                args[argName] = argValue

        return user, password, host, port, path, args

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        className = soClass.__name__
        assert not hasattr(self, className), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, className, soClass))
        setattr(self, className, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.
        """
        cacheSet = self.cache
        cacheSet.weakrefAll()
        for instance in cacheSet.getAll():
            instance.expire()
0250
class ConnWrapper(object):

    """
    This represents a SQLObject class that is bound to a specific
    connection (instances have a connection instance variable, but
    classes are global, so this binds the connection variable
    lazily when a class method is accessed)
    """

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        # Instantiate the wrapped class; the bound connection always
        # replaces any ``connection`` keyword the caller passed.
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        member = getattr(self._soClass, attr)
        if not isinstance(member, types.MethodType):
            # Plain attributes are handed back untouched.
            return member
        try:
            needsConnection = member.takes_connection
        except AttributeError:
            # First lookup of this method: inspect its signature and
            # remember the answer on the underlying function.
            spec = inspect.getargspec(member)
            argNames, varargs, varkw = spec[0], spec[1], spec[2]
            assert not varkw and not varargs, (
                "I cannot tell whether I must wrap this method, "
                "because it takes **kw: %r"
                % member)
            needsConnection = 'connection' in argNames
            member.im_func.takes_connection = needsConnection
        if needsConnection:
            return ConnMethodWrapper(member, self._connection)
        return member
0289
class ConnMethodWrapper(object):

    """
    Callable that invokes ``method`` with ``connection=`` forced to a
    fixed connection; every other attribute is read from the method.
    """

    def __init__(self, method, connection):
        self._method = method
        self._connection = connection

    def __getattr__(self, attr):
        # Reached only for names not found on the wrapper itself.
        wrapped = self._method
        return getattr(wrapped, attr)

    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        target = self._method
        return target(*args, **kw)

    def __repr__(self):
        description = (self._method, self._connection)
        return '<Wrapped %r with connection %r>' % description
0306
class DBAPI(DBConnection):

    """
    Connection base class that keeps a pool of raw connections and
    builds the SQL statements issued through them.

    Subclass must define a `makeConnection()` method, which
    returns a newly-created connection object.

    ``queryInsertID`` must also be defined.

    The methods below also read ``self.module``,
    ``self.supportTransactions``, ``self._queryInsertID``,
    ``self.joinSQLType`` and ``self.createIDColumn``, none of which is
    defined in this class.

    ``Transaction`` re-binds the functions of these methods to itself,
    so inside a method ``self`` may be a ``Transaction`` rather than a
    ``DBAPI`` instance.
    """

    # Used in URIs and passed to sqlrepr() to select the SQL dialect.
    dbName = None

    def __init__(self, **kw):
        # The pool is created before DBConnection.__init__ runs;
        # close() checks for its existence to detect a partly
        # constructed instance.
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        # Type of the objects produced by the driver's Binary()
        self._binaryType = type(self.module.Binary(''))

    def _runWithConnection(self, meth, *args):
        """
        Call ``meth(conn, *args)`` with a connection from the pool and
        release the connection afterwards, even if ``meth`` raises.
        """
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val

    def getConnection(self):
        """
        Return a raw connection: one popped from the pool, or a new one
        from ``makeConnection()`` when the pool is empty.  New
        connections receive a sequence number used in debug output.
        """
        self._poolLock.acquire()
        try:
            if not self._pool:
                conn = self.makeConnection()
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            else:
                conn = self._pool.pop()
            if self.debug:
                s = 'ACQUIRE'
                if self._pool is not None:
                    s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
                self.printDebug(conn, s, 'Pool')
            return conn
        finally:
            self._poolLock.release()

    def releaseConnection(self, conn, explicit=False):
        """
        Give ``conn`` back.

        For an implicit release on a backend that supports
        transactions, the connection is first committed or rolled back
        according to ``self.autoCommit``; the value ``'exception'``
        rolls back and raises.  The connection is then put into the
        pool, or closed when ``self._pool`` is ``None``.
        """
        if self.debug:
            if explicit:
                s = 'RELEASE (explicit)'
            else:
                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                s += ' no pooling'
            else:
                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        if self.supportTransactions and not explicit:
            if self.autoCommit == 'exception':
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                raise Exception, 'Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed'
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                # Skip the commit when the raw connection reports that
                # it is already in autocommit mode.
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        if self._pool is not None:
            # A connection may be released more than once; only pool
            # it if it is not already there.
            if conn not in self._pool:
                self._pool.insert(0, conn)
        else:
            conn.close()

    def printDebug(self, conn, s, name, type='query'):
        """
        Write one debug line for ``conn`` to ``self.debugWriter``.

        ``name`` labels the operation; messages labelled ``'Pool'`` are
        dropped unless ``self.debug == 'Pool'``.  ``type`` selects the
        separator: ``': '`` for queries, ``'->'`` otherwise.
        """
        if name == 'Pool' and self.debug != 'Pool':
            return
        if type == 'query':
            sep = ': '
        else:
            sep = '->'
        s = repr(s)
        n = self._connectionNumbers[id(conn)]
        # Pads the label to eight characters
        spaces = ' '*(8-len(name))
        if self.debugThreading:
            threadName = threading.currentThread().getName()
            threadName = (':' + threadName + ' '*(8-len(threadName)))
        else:
            threadName = ''
        # The format string is filled from the local variables above.
        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
        self.debugWriter.write(msg)

    def _executeRetry(self, conn, cursor, query):
        """
        Execute ``query`` on ``cursor``.  This implementation makes a
        single attempt; the name suggests subclasses may retry.
        """
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)

    def _query(self, conn, s):
        """Execute ``s`` on a new cursor of ``conn``; returns nothing."""
        if self.debug:
            self.printDebug(conn, s, 'Query')
        self._executeRetry(conn, conn.cursor(), s)

    def query(self, s):
        """Execute ``s`` on a pooled connection."""
        return self._runWithConnection(self._query, s)

    def _queryAll(self, conn, s):
        """Execute ``s`` on ``conn`` and return ``fetchall()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return value

    def queryAll(self, s):
        """Execute ``s`` on a pooled connection and return all rows."""
        return self._runWithConnection(self._queryAll, s)

    def _queryAllDescription(self, conn, s):
        """
        Like queryAll, but returns (description, rows), where the
        description is cursor.description (which gives row types)
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            # NOTE(review): the result is labelled 'QueryAll' although
            # the query above is labelled 'QueryAllDesc'.
            self.printDebug(conn, value, 'QueryAll', 'result')
        return c.description, value

    def queryAllDescription(self, s):
        """Run ``_queryAllDescription`` on a pooled connection."""
        return self._runWithConnection(self._queryAllDescription, s)

    def _queryOne(self, conn, s):
        """Execute ``s`` on ``conn`` and return ``fetchone()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchone()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryOne', 'result')
        return value

    def queryOne(self, s):
        """Execute ``s`` on a pooled connection and return one row."""
        return self._runWithConnection(self._queryOne, s)

    def _insertSQL(self, table, names, values):
        """
        Return an INSERT statement for ``table``; ``names`` are used as
        given and ``values`` are rendered with ``self.sqlrepr``.
        """
        return ("INSERT INTO %s (%s) VALUES (%s)" %
                (table, ', '.join(names),
                 ', '.join([self.sqlrepr(v) for v in values])))

    def transaction(self):
        """Return a new ``Transaction`` on this connection."""
        return Transaction(self)

    def queryInsertID(self, soInstance, id, names, values):
        """Run ``self._queryInsertID`` on a pooled connection."""
        return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)

    def iterSelect(self, select):
        """
        Return ``select.IterationClass`` over ``select`` with a
        connection from the pool; ``keepConnection=False`` tells the
        iteration object to release the connection itself.
        """
        return select.IterationClass(self, self.getConnection(),
                                     select, keepConnection=False)

    def accumulateSelect(self, select, *expressions):
        """
        Apply accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
        to the select object.

        Returns the single value when one expression is given, the
        whole result row otherwise.
        """
        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
        q = self.sqlrepr(q)
        val = self.queryOne(q)
        if len(expressions) == 1:
            val = val[0]
        return val

    def queryForSelect(self, select):
        """Return the SQL text for ``select``."""
        return self.sqlrepr(select.queryForSelect())

    def _SO_createJoinTable(self, join):
        """Create the intermediate table of ``join``."""
        self.query(self._SO_createJoinTableSQL(join))

    def _SO_createJoinTableSQL(self, join):
        """
        Return the CREATE TABLE statement for the intermediate table
        of ``join``; both columns use ``self.joinSQLType(join)``.
        """
        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
                (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join)))

    def _SO_dropJoinTable(self, join):
        """Drop the intermediate table of ``join``."""
        self.query("DROP TABLE %s" % join.intermediateTable)

    def _SO_createIndex(self, soClass, index):
        """Execute the statement built by ``createIndexSQL``."""
        self.query(self.createIndexSQL(soClass, index))

    def createIndexSQL(self, soClass, index):
        """Placeholder: always fails here; subclasses implement it."""
        assert 0, 'Implement in subclasses'

    def createTable(self, soClass):
        """
        Create the table for ``soClass`` and return the additional
        statements from ``createTableSQL``, which are not executed here.
        """
        createSql, constraints = self.createTableSQL(soClass)
        self.query(createSql)

        return constraints

    def createReferenceConstraints(self, soClass):
        """
        Return the non-empty results of ``createReferenceConstraint``
        for every ``col.SOForeignKey`` column of ``soClass``.
        """
        refConstraints = [self.createReferenceConstraint(soClass, column) for column in soClass.sqlmeta.columnList if isinstance(column, col.SOForeignKey)]
        refConstraintDefs = [constraint for constraint in refConstraints if constraint]
        return refConstraintDefs

    def createSQL(self, soClass):
        """
        Return ``soClass.sqlmeta.createSQL`` normalized to a list.

        The attribute may be a str, list, tuple, or a dict; for a dict
        the entry under ``soClass._connection.dbName`` is used.  A
        missing or empty value gives ``[]``.
        """
        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
        if tableCreateSQLs:
            assert isinstance(tableCreateSQLs,(str,list,dict,tuple)), (
                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
                (soClass.__name__))
            if isinstance(tableCreateSQLs, dict):
                tableCreateSQLs = tableCreateSQLs.get(soClass._connection.dbName, [])
            if isinstance(tableCreateSQLs, str):
                tableCreateSQLs = [tableCreateSQLs]
            if isinstance(tableCreateSQLs, tuple):
                tableCreateSQLs = list(tableCreateSQLs)
            assert isinstance(tableCreateSQLs,list), (
                'Unable to create a list from %s.sqlmeta.createSQL' %
                (soClass.__name__))
        return tableCreateSQLs or []

    def createTableSQL(self, soClass):
        """
        Return ``(createSql, extra)``: the CREATE TABLE statement and
        the list of reference constraints followed by the statements
        from ``createSQL``.
        """
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        createSql = ('CREATE TABLE %s (\n%s\n)' %
                     (soClass.sqlmeta.table, self.createColumns(soClass)))
        return createSql, constraints + extraSQL

    def createColumns(self, soClass):
        """
        Return the column definitions of ``soClass``, ID column first,
        joined with ``",\\n"``.
        """
        columnDefs = [self.createIDColumn(soClass)] + [self.createColumn(soClass, col)
                       for col in soClass.sqlmeta.columnList]
        return ",\n".join([" %s" % c for c in columnDefs])

    def createReferenceConstraint(self, soClass, col):
        """Placeholder: always fails here; subclasses implement it."""
        assert 0, "Implement in subclasses"

    def createColumn(self, soClass, col):
        """Placeholder: always fails here; subclasses implement it."""
        assert 0, "Implement in subclasses"

    def dropTable(self, tableName, cascade=False):
        """Drop ``tableName``.  ``cascade`` is not used here."""
        self.query("DROP TABLE %s" % tableName)

    def clearTable(self, tableName):
        """Delete every row of ``tableName`` with an unqualified DELETE."""
        self.query("DELETE FROM %s" % tableName)

    def createBinary(self, value):
        """
        Create a binary object wrapper for the given database.
        """
        return self.module.Binary(value)

    def _SO_update(self, so, values):
        """
        Update the row of ``so``; ``values`` is a sequence of
        ``(dbName, value)`` pairs.
        """
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                               for dbName, value in values]),
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectOne(self, so, columnNames):
        """Select ``columnNames`` from the row whose id is ``so.id``."""
        return self._SO_selectOneAlt(so, columnNames, so.q.id==so.id)

    def _SO_selectOneAlt(self, so, columnNames, condition):
        """
        Select one row from the table of ``so`` matching ``condition``.
        String column names are wrapped in ``sqlbuilder.SQLConstant``;
        an empty ``columnNames`` passes ``None`` as the column list.
        """
        if columnNames:
            columns = [isinstance(x, basestring) and sqlbuilder.SQLConstant(x) or x for x in columnNames]
        else:
            columns = None
        return self.queryOne(self.sqlrepr(sqlbuilder.Select(columns,
                                                            staticTables=[so.sqlmeta.table],
                                                            clause=condition)))

    def _SO_delete(self, so):
        """Delete the row of ``so``."""
        self.query("DELETE FROM %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectJoin(self, soClass, column, value):
        """
        Return the ID rows of ``soClass`` whose ``column`` equals
        ``value``.
        """
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (soClass.sqlmeta.idName,
                              soClass.sqlmeta.table,
                              column,
                              self.sqlrepr(value)))

    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        """
        Return the ``getColumn`` rows of ``table`` whose ``joinColumn``
        equals ``value``.
        """
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (getColumn,
                              table,
                              joinColumn,
                              self.sqlrepr(value)))

    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        """Delete the rows of ``table`` matching both column values."""
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
                   (table,
                    firstColumn,
                    self.sqlrepr(firstValue),
                    secondColumn,
                    self.sqlrepr(secondValue)))

    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        """Insert one row with the two column values into ``table``."""
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
                   (table,
                    firstColumn,
                    secondColumn,
                    self.sqlrepr(firstValue),
                    self.sqlrepr(secondValue)))

    def _SO_columnClause(self, soClass, kw):
        """
        Build an ``AND``-joined WHERE clause from keyword arguments.

        Recognized keys are ``id``, the column names of ``soClass`` and
        their ``foreignName``; recognized keys are popped from ``kw``,
        so the caller's dict is modified.  Returns ``None`` when no key
        was given; raises ``TypeError`` if unknown keys remain.
        """
        # A None value is compared with IS, everything else with '='
        ops = {None: "IS"}
        data = {}
        if 'id' in kw:
            data[soClass.sqlmeta.idName] = kw.pop('id')
        for key, col in soClass.sqlmeta.columns.items():
            if key in kw:
                value = kw.pop(key)
                if col.from_python:
                    value = col.from_python(value, sqlbuilder.SQLObjectState(soClass, connection=self))
                data[col.dbName] = value
            elif col.foreignName in kw:
                obj = kw.pop(col.foreignName)
                # Accept either an object or a raw ID for the foreign key
                if isinstance(obj, main.SQLObject):
                    data[col.dbName] = obj.id
                else:
                    data[col.dbName] = obj
        if kw:
            # Whatever is left matched no column
            raise TypeError, "got an unexpected keyword argument(s): %r" % kw.keys()

        if not data:
            return None
        return ' AND '.join(
            ['%s %s %s' %
             (dbName, ops.get(value, "="), self.sqlrepr(value))
             for dbName, value
             in data.items()])

    def sqlrepr(self, v):
        """Render ``v`` as SQL for this connection's ``dbName``."""
        return sqlrepr(v, self.dbName)

    def __del__(self):
        self.close()

    def close(self):
        """
        Close every pooled connection and empty the pool.  Errors of
        type ``self.module.Error`` raised while closing are ignored.
        """
        if not hasattr(self, '_pool'):
            # __init__ did not get as far as creating the pool, so
            # there is nothing to close.
            return
        if not self._pool:
            return
        self._poolLock.acquire()
        try:
            # Checked again now that the lock is held
            if not self._pool:
                return
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    pass
            del conn
            del conns
        finally:
            self._poolLock.release()

    def createEmptyDatabase(self):
        """
        Create an empty database.

        Not implemented here; always raises ``NotImplementedError``.
        """
        raise NotImplementedError
0704
0705class Iteration(object):
0706
0707 def __init__(self, dbconn, rawconn, select, keepConnection=False):
0708 self.dbconn = dbconn
0709 self.rawconn = rawconn
0710 self.select = select
0711 self.keepConnection = keepConnection
0712 self.cursor = rawconn.cursor()
0713 self.query = self.dbconn.queryForSelect(select)
0714 if dbconn.debug:
0715 dbconn.printDebug(rawconn, self.query, 'Select')
0716 self.dbconn._executeRetry(self.rawconn, self.cursor, self.query)
0717
0718 def __iter__(self):
0719 return self
0720
0721 def next(self):
0722 result = self.cursor.fetchone()
0723 if result is None:
0724 self._cleanup()
0725 raise StopIteration
0726 if result[0] is None:
0727 return None
0728 if self.select.ops.get('lazyColumns', 0):
0729 obj = self.select.sourceClass.get(result[0], connection=self.dbconn)
0730 return obj
0731 else:
0732 obj = self.select.sourceClass.get(result[0], selectResults=result[1:], connection=self.dbconn)
0733 return obj
0734
0735 def _cleanup(self):
0736 if getattr(self, 'query', None) is None:
0737
0738 return
0739 self.query = None
0740 if not self.keepConnection:
0741 self.dbconn.releaseConnection(self.rawconn)
0742 self.dbconn = self.rawconn = self.select = self.cursor = None
0743
0744 def __del__(self):
0745 self._cleanup()
0746
class Transaction(object):

    """
    A connection-like object that runs every statement on one raw
    connection until ``commit()`` or ``rollback()``.  Attributes it
    does not define are taken from the parent connection, with methods
    re-bound so that ``self`` is the transaction.
    """

    def __init__(self, dbConnection):
        # Marked obsolete first so that __del__ does nothing if the
        # rest of the constructor fails.
        self._obsolete = True
        self._dbConnection = dbConnection
        rawConnection = dbConnection.getConnection()
        self._connection = rawConnection
        dbConnection._setAutoCommit(rawConnection, 0)
        self.cache = CacheSet(cache=dbConnection.doCache)
        self._deletedCache = {}
        self._obsolete = False

    def assertActive(self):
        assert not self._obsolete, "This transaction has already gone through ROLLBACK; begin another transaction"

    def query(self, s):
        self.assertActive()
        parent = self._dbConnection
        return parent._query(self._connection, s)

    def queryAll(self, s):
        self.assertActive()
        parent = self._dbConnection
        return parent._queryAll(self._connection, s)

    def queryOne(self, s):
        self.assertActive()
        parent = self._dbConnection
        return parent._queryOne(self._connection, s)

    def queryInsertID(self, soInstance, id, names, values):
        self.assertActive()
        parent = self._dbConnection
        return parent._queryInsertID(
            self._connection, soInstance, id, names, values)

    def iterSelect(self, select):
        self.assertActive()
        # All rows are fetched up front, so no cursor stays open on
        # the transaction's connection while the caller iterates.
        iteration = select.IterationClass(self, self._connection,
                                          select, keepConnection=True)
        rows = list(iteration)
        return iter(rows)

    def _SO_delete(self, inst):
        # Remember the deleted ID under its class name so commit()
        # can expire the copy held in the parent connection's cache.
        className = inst.__class__.__name__
        self._deletedCache.setdefault(className, []).append(inst.id)
        # Run the parent's delete with this transaction as ``self``
        parentDelete = self._dbConnection._SO_delete.im_func
        bound = new.instancemethod(parentDelete, self, self.__class__)
        return bound(inst)

    def commit(self, close=False):
        """
        Commit, then expire the parent connection's cached copies of
        every object cached or deleted in this transaction.  With
        ``close`` the transaction also gives up its connection.
        """
        if self._obsolete:
            # Extra commits on a finished transaction are ignored.
            return
        parent = self._dbConnection
        if parent.debug:
            parent.printDebug(self._connection, '', 'COMMIT')
        self._connection.commit()
        touched = [(className, subCache.allIDs())
                   for className, subCache
                   in self.cache.allSubCachesByClassNames().items()]
        touched.extend(self._deletedCache.items())
        for className, objIDs in touched:
            for objID in objIDs:
                cached = self._dbConnection.cache.tryGetByName(objID, className)
                if cached is not None:
                    cached.expire()
        if close:
            self._makeObsolete()

    def rollback(self):
        """
        Roll back, expire every object cached in this transaction and
        give up the connection.
        """
        if self._obsolete:
            # Extra rollbacks on a finished transaction are ignored.
            return
        parent = self._dbConnection
        if parent.debug:
            parent.printDebug(self._connection, '', 'ROLLBACK')
        # The IDs are collected before the rollback is issued.
        touched = [(subCache, subCache.allIDs())
                   for subCache in self.cache.allSubCaches()]
        self._connection.rollback()

        for subCache, objIDs in touched:
            for objID in objIDs:
                cached = subCache.tryGet(objID)
                if cached is not None:
                    cached.expire()
        self._makeObsolete()

    def __getattr__(self, attr):
        """
        If nothing else works, let the parent connection handle it.
        Except with this transaction as 'self'.  Poor man's
        acquisition?  Bad programming?  Okay, maybe.
        """
        self.assertActive()
        found = getattr(self._dbConnection, attr)
        try:
            func = found.im_func
        except AttributeError:
            # Not a method: class wrappers are re-made around this
            # transaction, anything else is returned unchanged.
            if isinstance(found, ConnWrapper):
                return ConnWrapper(found._soClass, self)
            return found
        return new.instancemethod(func, self, self.__class__)

    def _makeObsolete(self):
        self._obsolete = True
        parent = self._dbConnection
        if parent.autoCommit:
            parent._setAutoCommit(self._connection, 1)
        parent.releaseConnection(self._connection,
                                 explicit=True)
        self._connection = None
        self._deletedCache = {}

    def begin(self):
        """Start again with a fresh connection after a finished transaction."""
        assert self._obsolete, "You cannot begin a new transaction session without rolling back this one"
        self._obsolete = False
        rawConnection = self._dbConnection.getConnection()
        self._connection = rawConnection
        self._dbConnection._setAutoCommit(rawConnection, 0)

    def __del__(self):
        # A transaction that is still open when collected is rolled back.
        if not self._obsolete:
            self.rollback()

    def close(self):
        raise TypeError('You cannot just close transaction - you should either call rollback(), commit() or commit(close=True) to close the underlying connection.')
0873
class ConnectionHub(object):

    """
    This object serves as a hub for connections, so that you can pass
    in a ConnectionHub to a SQLObject subclass as though it was a
    connection, but actually bind a real database connection later.
    You can also bind connections on a per-thread basis.

    You must hang onto the original ConnectionHub instance, as you
    cannot retrieve it again from the class or instance.

    To use the hub, do something like::

        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub

        hub.threadConnection = connectionFromURI('...')

    """

    def __init__(self):
        self.threadingLocal = threading_local()

    def __get__(self, obj, type=None):
        # An instance with its own ``_connection`` entry wins over the
        # hub; class-level access always resolves through the hub.
        if obj is not None:
            ownAttrs = obj.__dict__
            if '_connection' in ownAttrs:
                return ownAttrs['_connection']
        return self.getConnection()

    def __set__(self, obj, value):
        obj.__dict__['_connection'] = value

    def getConnection(self):
        """
        Return the connection bound to the current thread, else the
        process-wide one; ``AttributeError`` if neither is set.
        """
        try:
            return self.threadingLocal.connection
        except AttributeError:
            try:
                return self.processConnection
            except AttributeError:
                raise AttributeError(
                    "No connection has been defined for this thread "
                    "or process")

    def doInTransaction(self, func, *args, **kw):
        """
        This routine can be used to run a function in a transaction,
        rolling the transaction back if any exception is raised from
        that function, and committing otherwise.

        Use like::

            sqlhub.doInTransaction(process_request, os.environ)

        This will run ``process_request(os.environ)``.  The return
        value will be preserved.
        """
        # Remember which slot (thread or process) supplied the
        # connection, so the transaction replaces exactly that slot.
        try:
            previous = self.threadingLocal.connection
            perThread = True
        except AttributeError:
            previous = self.processConnection
            perThread = False

        def bind(connection):
            if perThread:
                self.threadConnection = connection
            else:
                self.processConnection = connection

        trans = previous.transaction()
        bind(trans)
        try:
            try:
                result = func(*args, **kw)
            except:
                trans.rollback()
                raise
            else:
                trans.commit(close=True)
                return result
        finally:
            # Put the original connection back in every case.
            bind(previous)

    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value

    def _get_threadConnection(self):
        return self.threadingLocal.connection

    def _del_threadConnection(self):
        del self.threadingLocal.connection

    threadConnection = property(_get_threadConnection,
                                _set_threadConnection,
                                _del_threadConnection)
0973
class ConnectionURIOpener(object):

    """
    Registry that maps URI schemes to connection-class builders and
    names to connection instances, and caches the connection made for
    each URI.
    """

    def __init__(self):
        self.schemeBuilders = {}   # scheme -> callable returning a class
        self.instanceNames = {}    # name -> connection instance
        self.cachedURIs = {}       # full URI (or name) -> connection

    def registerConnection(self, schemes, builder):
        """
        Register ``builder`` for every scheme in ``schemes``.
        Registering the same builder again is allowed; a different one
        for a taken scheme fails the assertion.
        """
        for uriScheme in schemes:
            assert not uriScheme in self.schemeBuilders or self.schemeBuilders[uriScheme] is builder, "A driver has already been registered for the URI scheme %s" % uriScheme
            self.schemeBuilders[uriScheme] = builder

    def registerConnectionInstance(self, inst):
        """
        Register ``inst`` under ``inst.name``; instances without a name
        are ignored.  Registering the same instance again is allowed; a
        different instance under a taken name, or a name containing
        ``':'``, fails the assertion.
        """
        if inst.name:
            # These checks used to refer to an undefined name ``cls``
            # and raised NameError instead of behaving as intended.
            assert not inst.name in self.instanceNames or self.instanceNames[inst.name] is inst, "A instance has already been registered with the name %s" % inst.name
            assert inst.name.find(':') == -1, "You cannot include ':' in your class names (%r)" % inst.name
            self.instanceNames[inst.name] = inst

    def connectionForURI(self, uri, oldUri=False, **args):
        """
        Return the connection for ``uri``, creating and caching it on
        first use.

        Extra keyword arguments are appended to the URI as query
        parameters.  A string without ``':'`` is looked up as the name
        of a registered instance.  ``oldUri`` selects the old-style URI
        parser of the connection class.
        """
        if args:
            if '?' not in uri:
                uri += '?' + urllib.urlencode(args)
            else:
                uri += '&' + urllib.urlencode(args)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if uri.find(':') != -1:
            scheme, rest = uri.split(':', 1)
            connCls = self.dbConnectionForScheme(scheme)
            if oldUri:
                conn = connCls.connectionFromOldURI(uri)
            else:
                conn = connCls.connectionFromURI(uri)
        else:
            # No scheme: this is the name of a registered instance
            assert uri in self.instanceNames, "No SQLObject driver exists under the name %s" % uri
            conn = self.instanceNames[uri]

        self.cachedURIs[uri] = conn
        return conn

    def dbConnectionForScheme(self, scheme):
        """
        Call the builder registered for ``scheme`` and return its
        result (the connection class).
        """
        assert scheme in self.schemeBuilders, (
            "No SQLObject driver exists for %s (only %s)"
            % (scheme, ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()
1025
# The single module-wide opener; its bound methods are exported below
# as module-level functions.
TheURIOpener = ConnectionURIOpener()

registerConnection = TheURIOpener.registerConnection
registerConnectionInstance = TheURIOpener.registerConnectionInstance
connectionForURI = TheURIOpener.connectionForURI
dbConnectionForScheme = TheURIOpener.dbConnectionForScheme
1032
1033
1034import firebird
1035import maxdb
1036import mssql
1037import mysql
1038import postgres
1039import rdbhost
1040import sqlite
1041import sybase