# coding=utf-8
""" Additional data types for sqlalchemy
"""
from __future__ import absolute_import, print_function, division
import sys
import logging
from distutils.version import StrictVersion
import pkg_resources
from functools import partial
import json
import uuid
import pytz
import babel
from flask_sqlalchemy import SQLAlchemy as SAExtension
import sqlalchemy as sa
from sqlalchemy.ext.mutable import Mutable
from .logging import patch_logger
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Version string of the installed Flask-SQLAlchemy distribution, read from
# package metadata. It is compared against '1.0' further down to decide which
# `SQLAlchemy` extension class this module exposes, then deleted.
FLASK_SA_VERSION = pkg_resources.get_distribution('Flask-SQLAlchemy').version
@sa.event.listens_for(sa.pool.Pool, "checkout")
def ping_connection(dbapi_connection, connection_record, connection_proxy):
    """
    Ensure connections are valid.

    From: `http://docs.sqlalchemy.org/en/rel_0_8/core/pooling.html`

    In case db has been restarted pool may return invalid connections. This
    listener runs on every pool checkout and issues a trivial ``SELECT 1``.

    :param dbapi_connection: raw DBAPI connection being checked out.
    :param connection_record: pool record for the connection (unused).
    :param connection_proxy: pool proxy for the connection (unused).
    :raises sqlalchemy.exc.DisconnectionError: if the probe query fails.
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("SELECT 1")
    except Exception:
        # `Exception` rather than a bare `except`: KeyboardInterrupt and
        # SystemExit must propagate untouched instead of being reported as a
        # lost connection.
        try:
            # Best effort: do not leak the cursor. A failure while closing a
            # cursor on a dead connection must not mask the disconnection.
            cursor.close()
        except Exception:
            pass
        # optional - dispose the whole pool
        # instead of invalidating one at a time
        # connection_proxy._pool.dispose()

        # raise DisconnectionError - pool will try
        # connecting again up to three times before raising.
        raise sa.exc.DisconnectionError()
    cursor.close()
class AbilianBaseSAExtension(SAExtension):
    """
    Base subclass of :class:`flask_sqlalchemy.SQLAlchemy` that layers our
    own driver-specific tweaks on top of the stock ones.
    """

    def apply_driver_hacks(self, app, info, options):
        """
        Apply the stock driver hacks, then adjust engine `options` in place
        according to the driver name found in `info`.
        """
        SAExtension.apply_driver_hacks(self, app, info, options)
        driver = info.drivername

        if driver == 'sqlite':
            # Unless the caller chose an isolation level, force it to None.
            # This is required to support savepoints/rollback without error:
            # it disables the implicit BEGIN/COMMIT statements made by
            # pysqlite (a COMMIT kills all savepoints made).
            sqlite_args = options.setdefault('connect_args', {})
            sqlite_args.setdefault('isolation_level', None)
            return

        if driver.startswith('postgres'):
            options.setdefault('client_encoding', 'utf8')
if StrictVersion(FLASK_SA_VERSION) > StrictVersion('1.0'):
    # Flask-SQLAlchemy > 1.0: bind parameter is supported
    SQLAlchemy = AbilianBaseSAExtension
else:
    # SA extension's scoped session supports 'bind' parameter only after 1.0.
    # This is a fix for it. This is required to ensure transaction rollback
    # during tests, but it's useful in some use cases too.
    from flask_sqlalchemy import _SignallingSession as BaseSession

    class SignallingSession(BaseSession):
        """Signalling session that honours an explicit ``bind`` option."""

        def __init__(self, db, autocommit=False, autoflush=True, **options):
            self.app = db.get_app()
            self._model_changes = {}
            explicit_bind = options.pop('bind', None)
            # We are overriding BaseSession.__init__ on purpose, so it must
            # not be called: go straight to the plain SQLAlchemy Session
            # initializer (BaseSession's own parent).
            sa.orm.Session.__init__(
                self,
                autocommit=autocommit,
                autoflush=autoflush,
                bind=explicit_bind or db.engine,
                binds=db.get_binds(self.app),
                **options)

    class SQLAlchemy(AbilianBaseSAExtension):
        """Extension whose scoped sessions use our `SignallingSession`."""

        def create_scoped_session(self, options=None):
            """Helper factory method that creates a scoped session."""
            # override needed to use our SignallingSession implementation
            options = {} if options is None else options
            scopefunc = options.pop('scopefunc', None)
            session_factory = partial(SignallingSession, self, **options)
            return sa.orm.scoped_session(session_factory, scopefunc=scopefunc)

# The version string was only needed for the switch above.
del FLASK_SA_VERSION
# PATCH flask_sqlalchemy to report a proper query origin in the debug toolbar.
#
# The original code only works when code from the current app is on the call
# stack. When a query is issued from a 3rd-party app, it is logged but its
# source is marked "unknown". Our patch makes a "best guess" instead.
def _calling_context(app_path):
    """
    Describe the code location that most likely triggered the current query.

    Walks the call stack upward, starting at the caller of this function:

    * the first frame whose module is `app_path` or one of its submodules
      wins and is returned immediately;
    * otherwise, the first frame found right after leaving a run of
      ``sqlalchemy`` frames (i.e. whoever called into SQLAlchemy) is kept as
      a best guess;
    * if neither is found, ``'<unknown>'`` is returned.

    :param app_path: dotted module name of the application.
    :return: ``'<filename>:<lineno> (<function>)'`` or ``'<unknown>'``.
    """
    frm = sys._getframe(1)
    entered_sa_code = exited_sa_code = False
    sa_caller = '<unknown>'
    format_name = ('{frm.f_code.co_filename}:{frm.f_lineno} '
                   '({frm.f_code.co_name})'.format)

    # Loop until the stack is exhausted. Testing `frm.f_back` instead would
    # skip the outermost frame, so a query issued directly from a top-level
    # script module or from a thread entry point would never be matched.
    while frm is not None:
        name = frm.f_globals.get('__name__')
        if name and (name == app_path or name.startswith(app_path + '.')):
            return format_name(frm=frm)

        if not exited_sa_code:
            in_sa_code = name and (name == 'sqlalchemy'
                                   or name.startswith('sqlalchemy.'))
            if not entered_sa_code:
                entered_sa_code = in_sa_code
            elif not in_sa_code:
                # exited from sa stack: retain name
                sa_caller = format_name(frm=frm)
                exited_sa_code = True

        frm = frm.f_back

    return sa_caller
# Imported under an alias only for the duration of the patch.
import flask_sqlalchemy as flask_sa
# Log the function that is about to be replaced (presumably so the patch is
# traceable in logs; patch_logger comes from the local `.logging` module).
patch_logger.info(flask_sa._calling_context)
# Swap in our implementation defined above.
flask_sa._calling_context = _calling_context
# Drop the alias so it does not remain in this module's namespace.
del flask_sa
# END PATCH
def filter_cols(model, *filtered_columns):
    """
    Return the column names of `model`, minus the ones listed in
    `filtered_columns`. Useful with defer(), for example, to retain only the
    columns of interest.
    """
    mapper = sa.orm.class_mapper(model)
    column_keys = set()
    for prop in mapper.iterate_properties:
        # only properties exposing `columns` are column-backed
        if hasattr(prop, 'columns'):
            column_keys.add(prop.key)
    return list(column_keys.difference(filtered_columns))
class MutationDict(Mutable, dict):
    """Provides a dictionary type with mutability support.

    Every mutating dict method calls :meth:`changed` once the mutation has
    actually happened, so that the owning attribute gets flagged as modified.
    """

    @classmethod
    def coerce(cls, key, value):
        """Convert plain dictionaries to MutationDict."""
        if isinstance(value, MutationDict):
            return value
        if isinstance(value, dict):
            return MutationDict(value)
        # this call will raise ValueError
        return Mutable.coerce(key, value)

    # pickling support. see:
    # http://docs.sqlalchemy.org/en/rel_0_8/orm/extensions/mutable.html#supporting-pickling
    def __getstate__(self):
        return dict(self)

    def __setstate__(self, state):
        self.update(state)

    # dict methods
    def __setitem__(self, key, value):
        """Detect dictionary set events and emit change events."""
        dict.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        """Detect dictionary del events and emit change events."""
        dict.__delitem__(self, key)
        self.changed()

    def clear(self):
        """Remove all items and emit a change event."""
        dict.clear(self)
        self.changed()

    def update(self, other=(), **kwargs):
        """Same contract as :meth:`dict.update`, plus a change event.

        `other` keeps its original name and position; it is now optional and
        keyword items are accepted, as with a plain dict.
        """
        dict.update(self, other, **kwargs)
        self.changed()

    def setdefault(self, key, failobj=None):
        """Same as :meth:`dict.setdefault`; emits a change event only when
        `key` was actually inserted."""
        inserted = key not in self
        result = dict.setdefault(self, key, failobj)
        if inserted:
            self.changed()
        return result

    def pop(self, key, *args):
        """Same as :meth:`dict.pop`, plus a change event.

        The event is emitted after the removal, so a KeyError on a missing
        key does not flag an unchanged dict as modified.
        """
        result = dict.pop(self, key, *args)
        self.changed()
        return result

    def popitem(self):
        """Same as :meth:`dict.popitem`, plus a change event (not emitted
        if the dict is empty and KeyError is raised)."""
        result = dict.popitem(self)
        self.changed()
        return result
class MutationList(Mutable, list):
    """
    List type with mutability support: each mutating operation is followed
    by a call to :meth:`changed`.
    """

    @classmethod
    def coerce(cls, key, value):
        """Convert list to MutationList."""
        if isinstance(value, MutationList):
            return value
        if isinstance(value, list):
            return MutationList(value)
        # this call will raise ValueError
        return Mutable.coerce(key, value)

    # pickling support. see:
    # http://docs.sqlalchemy.org/en/rel_0_8/orm/extensions/mutable.html#supporting-pickling
    def __getstate__(self):
        # instance attributes, minus the `_parents` entry
        state = dict(self.__dict__)
        state.pop('_parents', None)
        return state

    # list methods
    def __setitem__(self, idx, value):
        list.__setitem__(self, idx, value)
        self.changed()

    def __delitem__(self, idx):
        list.__delitem__(self, idx)
        self.changed()

    def insert(self, idx, value):
        list.insert(self, idx, value)
        self.changed()

    def __setslice__(self, i, j, other):
        list.__setslice__(self, i, j, other)
        self.changed()

    def __delslice__(self, i, j):
        list.__delslice__(self, i, j)
        self.changed()

    def __iadd__(self, other):
        result = list.__iadd__(self, other)
        self.changed()
        return result

    def __imul__(self, n):
        result = list.__imul__(self, n)
        self.changed()
        return result

    def append(self, item):
        list.append(self, item)
        self.changed()

    def pop(self, i=-1):
        removed = list.pop(self, i)
        self.changed()
        return removed

    def remove(self, item):
        list.remove(self, item)
        self.changed()

    def reverse(self):
        list.reverse(self)
        self.changed()

    def sort(self, *args, **kwargs):
        list.sort(self, *args, **kwargs)
        self.changed()

    def extend(self, other):
        list.extend(self, other)
        self.changed()
class JSON(sa.types.TypeDecorator):
    """Stores any structure serializable with json.

    Usage

    JSON()

    Takes same parameters as sqlalchemy.types.Text
    """
    impl = sa.types.Text

    def process_bind_param(self, value, dialect):
        """Serialize `value` to a JSON string; None is passed through."""
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        """Parse the stored JSON string; None is passed through."""
        return None if value is None else json.loads(value)
class JSONUniqueListType(JSON):
    """Store a list in JSON format, with items made unique and sorted."""

    @property
    def python_type(self):
        return MutationList

    def process_bind_param(self, value, dialect):
        """Uniquify and sort list/tuple values, then serialize as JSON."""
        # value may be a simple string used in a LIKE clause for instance, so
        # only list-like values are uniquified/sorted (None is neither a
        # tuple nor a list, so it falls through untouched)
        if isinstance(value, (tuple, list)):
            value = sorted(set(value))
        return JSON.process_bind_param(self, value, dialect)
def JSONDict(*args, **kwargs):
    """
    Build a column type storing a dict as JSON in the database, with
    mutability support. Arguments are forwarded to :class:`JSON`.
    """
    json_type = JSON(*args, **kwargs)
    return MutationDict.as_mutable(json_type)
def JSONList(*args, **kwargs):
    """
    Build a column type storing a list as JSON in the database, with
    mutability support.

    If kwargs has a param `unique_sorted` (which evaluates to True), list
    values are made unique and sorted. That param is consumed here; all
    other arguments are forwarded to the underlying type.
    """
    unique_sorted = kwargs.pop('unique_sorted', False)
    column_type = JSONUniqueListType if unique_sorted else JSON
    return MutationList.as_mutable(column_type(*args, **kwargs))
class UUID(sa.types.TypeDecorator):
    """
    Platform-independent UUID type.

    Uses Postgresql's UUID type, otherwise uses
    CHAR(32), storing as stringified hex values.

    From SQLAlchemy documentation.
    """
    impl = sa.types.CHAR

    def load_dialect_impl(self, dialect):
        """Pick the native UUID type on PostgreSQL, CHAR(32) elsewhere."""
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(sa.dialects.postgresql.UUID())
        return dialect.type_descriptor(sa.types.CHAR(32))

    def process_bind_param(self, value, dialect):
        """Convert a :class:`uuid.UUID` (or its string form) for storage.

        :raises ValueError: (from :class:`uuid.UUID`) when a non-UUID value
            cannot be parsed on a non-PostgreSQL backend.
        """
        if value is None:
            return value
        if dialect.name == 'postgresql':
            return str(value)
        if not isinstance(value, uuid.UUID):
            value = uuid.UUID(value)
        # 32-character zero-padded hex string. `.hex` is used instead of
        # `"%.32x" % value`, which relies on implicit int coercion of the
        # UUID object and fails on Python 3.
        return value.hex

    def process_result_value(self, value, dialect):
        """Return a :class:`uuid.UUID` (or None) from the stored value."""
        # Some drivers may already hand back UUID objects; uuid.UUID() would
        # fail on those, so pass them through.
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)

    def compare_against_backend(self, dialect, conn_type):
        """Tell whether the reflected `conn_type` matches this type."""
        if dialect.name == 'postgresql':
            return isinstance(conn_type, sa.dialects.postgresql.UUID)
        return isinstance(conn_type, sa.types.CHAR)
class Locale(sa.types.TypeDecorator):
    """
    Store a :class:`babel.Locale` instance
    """
    impl = sa.types.UnicodeText

    @property
    def python_type(self):
        return babel.Locale

    def process_bind_param(self, value, dialect):
        """Convert a locale (object or identifier string) to its stored code.

        The stored code is ``language``, followed by ``_territory`` when the
        locale has a territory, else by ``_script`` when it has a script.

        :return: unicode code, or None for None / blank strings.
        :raises ValueError: if `value` is neither a :class:`babel.Locale`
            nor a string.
        """
        if value is None:
            return None

        if not isinstance(value, babel.Locale):
            if not isinstance(value, basestring):
                raise ValueError("Unknown locale value: %s" % repr(value))
            if not value.strip():
                return None
            value = babel.Locale.parse(value)

        code = unicode(value.language)
        if value.territory:
            code += u'_' + unicode(value.territory)
        elif value.script:
            # Append the script itself. The territory is known to be empty
            # in this branch, so using it would yield e.g. u'zh_None'.
            code += u'_' + unicode(value.script)
        return code

    def process_result_value(self, value, dialect):
        """Parse the stored code back into a :class:`babel.Locale`."""
        return None if value is None else babel.Locale.parse(value)
class Timezone(sa.types.TypeDecorator):
    """
    Store a pytz timezone (e.g. :class:`pytz.tzfile.DstTzInfo`) by its zone
    name.
    """
    impl = sa.types.UnicodeText

    @property
    def python_type(self):
        return pytz.tzfile.DstTzInfo

    def process_bind_param(self, value, dialect):
        """Convert a timezone (pytz object or zone name) to its zone name.

        :return: zone name, or None for None / blank strings.
        :raises ValueError: if `value` is neither a pytz timezone nor a
            string.
        """
        if value is None:
            return None

        # Accept every pytz timezone flavour, not only DstTzInfo. Static
        # zones and UTC are valid values too, and are what
        # process_result_value returns for names such as 'UTC', so rejecting
        # them would break writing back a value loaded from the database.
        # type(pytz.utc) is listed explicitly in case the installed pytz
        # does not derive its UTC class from BaseTzInfo.
        tz_types = (pytz.tzinfo.BaseTzInfo, type(pytz.utc))
        if not isinstance(value, tz_types):
            if not isinstance(value, basestring):
                raise ValueError("Unknown timezone value: %s" % repr(value))
            if not value.strip():
                return None
            value = babel.dates.get_timezone(value)

        return value.zone

    def process_result_value(self, value, dialect):
        """Turn the stored zone name back into a timezone object."""
        return None if value is None else babel.dates.get_timezone(value)