0001import dbconnection
0002import joins
0003import main
0004import sqlbuilder
0005
# Public names exported by ``from <this module> import *``.
__all__ = ['SelectResults']
0007
class SelectResults(object):
    """Chainable description of a SELECT over ``sourceClass``.

    Methods such as ``orderBy``, ``connection``, ``limit``, ``distinct``,
    ``filter`` and non-negative slicing return a new ``SelectResults`` built
    through ``clone``/``newClause``.  Rows are fetched from the connection
    when the object is iterated (``__iter__``/``lazyIter``), indexed with an
    integer, sliced with a negative bound, or aggregated (``count``,
    ``accumulate`` and friends).
    """

    # NOTE(review): not referenced inside this class; presumably looked up by
    # the connection layer when iterating results -- confirm in dbconnection.
    IterationClass = dbconnection.Iteration

    def __init__(self, sourceClass, clause, clauseTables=None,
                 **ops):
        """Build a result set.

        :param sourceClass: the SQLObject class whose rows are selected.
        :param clause: WHERE clause.  ``None`` or the string ``'all'`` selects
            every row (``sqlbuilder.SQLTrueClause``); any other value that is
            not a ``sqlbuilder.SQLExpression`` is wrapped in
            ``sqlbuilder.SQLConstant``.
        :param clauseTables: extra table names added to the FROM list.
        :param ops: query options read elsewhere in this class (``orderBy``,
            ``connection``, ``limit``, ``start``, ``end``, ``distinct``,
            ``lazyColumns``, ``reversed``, ``join``, ``forUpdate``).
            ``limit`` is rewritten into ``start=0, end=limit`` and must not
            be combined with ``start``/``end``.
        """
        self.sourceClass = sourceClass
        if clause is None or (isinstance(clause, basestring) and clause == 'all'):
            clause = sqlbuilder.SQLTrueClause
        if not isinstance(clause, sqlbuilder.SQLExpression):
            clause = sqlbuilder.SQLConstant(clause)
        self.clause = clause
        # ``ops`` is kept by reference and normalized in place below.
        self.ops = ops
        if ops.get('orderBy', sqlbuilder.NoDefault) is sqlbuilder.NoDefault:
            ops['orderBy'] = sourceClass.sqlmeta.defaultOrder
        orderBy = ops['orderBy']
        if isinstance(orderBy, (tuple, list)):
            # Build a real list so the munged ordering can be reused safely.
            orderBy = [self._mungeOrderBy(item) for item in orderBy]
        else:
            orderBy = self._mungeOrderBy(orderBy)
        ops['dbOrderBy'] = orderBy
        # An explicit connection=None means "use the class default".
        if 'connection' in ops and ops['connection'] is None:
            del ops['connection']
        if ops.get('limit', None):
            assert not ops.get('start', None) and not ops.get('end', None), \
                "'limit' cannot be used with 'start' or 'end'"
            ops["start"] = 0
            ops["end"] = ops.pop("limit")

        tablesSet = sqlbuilder.tablesUsedSet(self.clause, self._getConnection().dbName)
        if clauseTables:
            for table in clauseTables:
                tablesSet.add(table)
        self.clauseTables = clauseTables

        self.tables = list(tablesSet) + [sourceClass.sqlmeta.table]

    def queryForSelect(self):
        """Return the ``sqlbuilder.Select`` for this result set.

        The selected columns are the id column followed by every column in
        ``sqlmeta.columnList``.  The WHERE clause, join, ordering, slice
        bounds and flags come from ``self.ops``.
        """
        columns = [self.sourceClass.q.id] + [getattr(self.sourceClass.q, x.name) for x in self.sourceClass.sqlmeta.columnList]
        query = sqlbuilder.Select(columns,
                                  where=self.clause,
                                  join=self.ops.get('join', sqlbuilder.NoDefault),
                                  distinct=self.ops.get('distinct', False),
                                  lazyColumns=self.ops.get('lazyColumns', False),
                                  start=self.ops.get('start', 0),
                                  end=self.ops.get('end', None),
                                  orderBy=self.ops.get('dbOrderBy', sqlbuilder.NoDefault),
                                  reversed=self.ops.get('reversed', False),
                                  staticTables=self.tables,
                                  forUpdate=self.ops.get('forUpdate', False))
        return query

    def __repr__(self):
        return "<%s at %x>" % (self.__class__.__name__, id(self))

    def _getConnection(self):
        """Return the ``connection`` op if set, else the class's connection."""
        return self.ops.get('connection') or self.sourceClass._connection

    def __str__(self):
        """Return the SQL text the connection builds for this result set."""
        conn = self._getConnection()
        return conn.queryForSelect(self)

    def _mungeOrderBy(self, orderBy):
        """Translate one ``orderBy`` item into something the Select accepts.

        Non-string values are returned unchanged.  For strings, a leading
        ``-`` requests descending order (``sqlbuilder.DESC``).  A name found
        in ``sqlmeta.columns`` becomes that column's ``q`` attribute; any
        other string becomes a raw ``sqlbuilder.SQLConstant``.
        """
        if not isinstance(orderBy, basestring):
            return orderBy
        desc = orderBy.startswith('-')
        if desc:
            orderBy = orderBy[1:]
        columns = self.sourceClass.sqlmeta.columns
        if orderBy in columns:
            val = getattr(self.sourceClass.q, columns[orderBy].name)
        else:
            val = sqlbuilder.SQLConstant(orderBy)
        if desc:
            return sqlbuilder.DESC(val)
        return val

    def clone(self, **newOps):
        """Return a copy of this result set with ``newOps`` overriding ops."""
        ops = self.ops.copy()
        ops.update(newOps)
        return self.__class__(self.sourceClass, self.clause,
                              self.clauseTables, **ops)

    def orderBy(self, orderBy):
        """Return a copy ordered by ``orderBy`` (see ``_mungeOrderBy``)."""
        return self.clone(orderBy=orderBy)

    def connection(self, conn):
        """Return a copy that runs its queries through ``conn``."""
        return self.clone(connection=conn)

    def limit(self, limit):
        """Return ``self[:limit]``."""
        return self[:limit]

    def lazyColumns(self, value):
        """Return a copy with the ``lazyColumns`` option set to ``value``."""
        return self.clone(lazyColumns=value)

    def reversed(self):
        """Return a copy with the ``reversed`` option toggled."""
        return self.clone(reversed=not self.ops.get('reversed', False))

    def distinct(self):
        """Return a copy with the ``distinct`` option enabled."""
        return self.clone(distinct=True)

    def newClause(self, new_clause):
        """Return a result set with the same ops but WHERE ``new_clause``."""
        return self.__class__(self.sourceClass, new_clause,
                              self.clauseTables, **self.ops)

    def filter(self, filter_clause):
        """Return a result set restricted by ``self.clause AND filter_clause``.

        ``None`` is a no-op and returns ``self``.
        """
        if filter_clause is None:
            return self
        clause = self.clause
        if isinstance(clause, basestring):
            clause = sqlbuilder.SQLConstant('(%s)' % clause)
        return self.newClause(sqlbuilder.AND(clause, filter_clause))

    def __getitem__(self, value):
        """Slice or index the result set.

        * A slice with non-negative bounds returns a new result set whose
          ``start``/``end`` ops are composed with any existing slice and never
          extend past it.
        * A slice with a negative bound fetches every row and slices the list.
        * ``sr[i]`` with ``i >= 0`` fetches just that row; a negative ``i``
          fetches every row and indexes the list.

        :raises AssertionError: if the slice has a step.
        :raises IndexError: if an integer index lies outside the results.
        """
        if isinstance(value, slice):
            assert not value.step, "Slices do not support steps"
            if not value.start and not value.stop:
                # Full slice: nothing to restrict.
                # NOTE(review): an empty slice such as sr[:0] (and limit(0))
                # also lands here and returns the unrestricted result set.
                return self

            if (value.start and value.start < 0) or (value.stop and value.stop < 0):
                # Counting from the end cannot be expressed as start/end
                # offsets, so fall back to fetching all rows and slicing.
                if value.start:
                    if value.stop:
                        return list(self)[value.start:value.stop]
                    return list(self)[value.start:]
                return list(self)[:value.stop]

            # From here on both bounds are non-negative (or None) and are
            # interpreted relative to any slice already applied.
            baseStart = self.ops.get('start', 0)
            baseEnd = self.ops.get('end', None)
            if value.start:
                start = baseStart + value.start
                if value.stop is not None:
                    if value.stop < value.start:
                        # Inverted slice: empty range.
                        end = start
                    else:
                        end = value.stop + baseStart
                else:
                    end = baseEnd
            else:
                start = baseStart
                end = value.stop + start
            if baseEnd is not None:
                # Never extend past a slice/limit that was already applied.
                end = min(end, baseEnd)
            if end is not None and end < start:
                # Slicing beyond an existing bound gives an empty range
                # instead of an inverted one (end < start).
                end = start
            return self.clone(start=start, end=end)
        else:
            if value < 0:
                return list(iter(self))[value]
            start = self.ops.get('start', 0) + value
            end = self.ops.get('end', None)
            if end is not None and start >= end:
                # The index lies past an existing slice/limit; the one-row
                # query below would otherwise ignore that bound.
                raise IndexError("SelectResults index out of range")
            return list(self.clone(start=start, end=start + 1))[0]

    def __iter__(self):
        """Fetch all rows (via ``lazyIter``) and iterate over the list.

        Use ``lazyIter`` directly to pull rows from the connection on demand.
        """
        return iter(list(self.lazyIter()))

    def lazyIter(self):
        """
        Returns an iterator that will lazily pull rows out of the
        database and return SQLObject instances
        """
        conn = self._getConnection()
        return conn.iterSelect(self)

    def accumulate(self, *expressions):
        """Run an aggregate SELECT over this result set.

        Each expression that is not a ``sqlbuilder.SQLExpression`` is wrapped
        in ``sqlbuilder.SQLConstant`` before being passed, together with this
        result set, to the connection's ``accumulateSelect``.  Its result is
        returned as-is.
        """
        conn = self._getConnection()
        exprs = []
        for expr in expressions:
            if not isinstance(expr, sqlbuilder.SQLExpression):
                expr = sqlbuilder.SQLConstant(expr)
            exprs.append(expr)
        return conn.accumulateSelect(self, *exprs)

    def count(self):
        """Return the number of rows matched by this result set.

        Uses ``COUNT(DISTINCT id)`` when the ``distinct`` option is set and
        ``COUNT(*)`` otherwise.  Sliced or limited result sets are rejected
        by assertion.
        """
        assert not self.ops.get('start') and not self.ops.get('end'), \
            "start/end/limit have no meaning with 'count'"
        assert not (self.ops.get('distinct') and (self.ops.get('start')
                                                  or self.ops.get('end'))), \
            "distinct-counting of sliced objects is not supported"
        if self.ops.get('distinct'):
            # COUNT(DISTINCT ...) needs an explicit column, so count the
            # unique id column.
            count = self.accumulate('COUNT(DISTINCT %s)' % self._getConnection().sqlrepr(self.sourceClass.q.id))
        else:
            count = self.accumulate('COUNT(*)')
        # The adjustments below only run when assertions are disabled
        # (python -O); keep the result within [0, end - start].
        start = self.ops.get('start') or 0
        if start:
            count = max(count - start, 0)
        if self.ops.get('end'):
            count = max(min(self.ops['end'] - start, count), 0)
        return count

    def accumulateMany(self, *attributes):
        """ Making the expressions for count/sum/min/max/avg
            of a given select result attributes.
            `attributes` must be a list/tuple of pairs (func_name, attribute);
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)

            String attributes (str or unicode) are used verbatim; anything
            else is rendered with the connection's ``sqlrepr``.  When the
            ``distinct`` option is set, each call becomes ``FUNC(DISTINCT x)``.
        """
        expressions = []
        conn = self._getConnection()
        if self.ops.get('distinct'):
            distinct = 'DISTINCT '
        else:
            distinct = ''
        for func_name, attribute in attributes:
            # basestring, not str: a unicode column name must not be passed
            # to sqlrepr, which would quote it as a string literal.
            if not isinstance(attribute, basestring):
                attribute = conn.sqlrepr(attribute)
            expression = '%s(%s%s)' % (func_name, distinct, attribute)
            expressions.append(expression)
        return self.accumulate(*expressions)

    def accumulateOne(self, func_name, attribute):
        """ Making the sum/min/max/avg of a given select result attribute.
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        return self.accumulateMany((func_name, attribute))

    def sum(self, attribute):
        """Return ``SUM(attribute)`` over this result set."""
        return self.accumulateOne("SUM", attribute)

    def min(self, attribute):
        """Return ``MIN(attribute)`` over this result set."""
        return self.accumulateOne("MIN", attribute)

    def avg(self, attribute):
        """Return ``AVG(attribute)`` over this result set."""
        return self.accumulateOne("AVG", attribute)

    def max(self, attribute):
        """Return ``MAX(attribute)`` over this result set."""
        return self.accumulateOne("MAX", attribute)

    def getOne(self, default=sqlbuilder.NoDefault):
        """
        If a query is expected to only return a single value,
        using ``.getOne()`` will return just that value.

        If no results are found, ``SQLObjectNotFound`` will be
        raised, unless you pass in a default value (like
        ``.getOne(None)``).

        If more than one result is returned,
        ``SQLObjectIntegrityError`` will be raised.
        """
        results = list(self)
        if not results:
            if default is sqlbuilder.NoDefault:
                raise main.SQLObjectNotFound(
                    "No results matched the query for %s"
                    % self.sourceClass.__name__)
            return default
        if len(results) > 1:
            raise main.SQLObjectIntegrityError(
                "More than one result returned from query: %s"
                % results)
        return results[0]

    def throughTo(self):
        """Attribute-style access to ``_throughTo``.

        ``sr.throughTo.name`` is ``sr._throughTo('name')``.
        """
        class _throughTo_getter(object):
            def __init__(self, inst):
                self.sresult = inst
            def __getattr__(self, attr):
                return self.sresult._throughTo(attr)
        return _throughTo_getter(self)
    throughTo = property(throughTo)

    def _throughTo(self, attr):
        """Select the objects reached from these results through ``attr``.

        ``attr`` names either a foreign-key column (with or without the
        ``ID`` suffix) or a join's ``joinMethodName``.  Joins exposing
        ``otherColumn`` go through ``_throughToRelatedJoin``; other joins go
        through ``_throughToMultipleJoin``.

        :raises AttributeError: if ``attr`` matches neither.
        """
        otherClass = None
        orderBy = sqlbuilder.NoDefault

        if attr.endswith('ID'):
            refName = attr
        else:
            refName = attr + 'ID'
        ref = self.sourceClass.sqlmeta.columns.get(refName, None)
        if ref and ref.foreignKey:
            otherClass, clause = self._throughToFK(ref)
        else:
            join = [x for x in self.sourceClass.sqlmeta.joins if x.joinMethodName == attr]
            if join:
                join = join[0]
                orderBy = join.orderBy
                if hasattr(join, 'otherColumn'):
                    otherClass, clause = self._throughToRelatedJoin(join)
                else:
                    otherClass, clause = self._throughToMultipleJoin(join)

        if not otherClass:
            raise AttributeError("throughTo argument (got %s) should be name of foreignKey or SQL*Join in %s" % (attr, self.sourceClass))

        return otherClass.select(clause,
                                 orderBy=orderBy,
                                 connection=self._getConnection())

    def _throughToFK(self, col):
        """Return ``(otherClass, clause)`` for following foreign key ``col``.

        The clause matches ``otherClass.q.id`` against the distinct, unordered
        values of ``col`` in this result set, selected as an aliased subquery.
        """
        otherClass = getattr(self.sourceClass, "_SO_class_" + col.foreignKey)
        colName = col.name
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(getattr(self.sourceClass.q, colName), colName)]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, col.name))
        return otherClass, otherClass.q.id == getattr(query.q, colName)

    def _throughToMultipleJoin(self, join):
        """Return ``(otherClass, clause)`` for following a join without
        an intermediate table.

        The clause matches the join column on ``otherClass`` against the
        distinct ids of this result set, selected as an aliased subquery.
        """
        otherClass = join.otherClass
        colName = join.soClass.sqlmeta.style.dbColumnToPythonAttr(join.joinColumn)
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
        joinColumn = getattr(otherClass.q, colName)
        return otherClass, joinColumn == query.q.id

    def _throughToRelatedJoin(self, join):
        """Return ``(otherClass, clause)`` for following a join through
        ``join.intermediateTable``.

        The clause links ``otherClass.q.id`` to ``join.otherColumn`` of the
        intermediate table. It also links that table's ``join.joinColumn`` to
        the distinct ids of this result set.
        """
        otherClass = join.otherClass
        intTable = sqlbuilder.Table(join.intermediateTable)
        colName = join.joinColumn
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
        clause = sqlbuilder.AND(otherClass.q.id == getattr(intTable, join.otherColumn),
                                getattr(intTable, colName) == query.q.id)
        return otherClass, clause