# coding=utf-8
"""
Create all standard extensions.
"""
# Note: Because of issues with circular dependencies, Abilian-specific
# extensions are created later.
from __future__ import absolute_import, print_function, division
from abilian.core.logging import patch_logger
from sqlalchemy.engine import Engine
from flask import current_app
from . import upstream_info
from .login import login_manager
# Names exported by a star-import of this module.
__all__ = ['get_extension', 'db', 'mail', 'login_manager', 'csrf',
'upstream_info']
# Standard extensions.
import flask_mail
# Patch flask_mail.Message.send so that the default mail sender is always
# passed as envelope_from.
# FIXME: we'd rather subclass Message and update all imports
def _message_send(self, connection):
    """
    Replacement for `flask_mail.Message.send`.

    Sets the ``Sender`` extra header of the message to the application's
    ``MAIL_SENDER`` setting, then hands the message over to `connection`,
    passing that same address as second argument.

    :param connection: object whose `send` method is called with this message
        and the sender address.
    """
    default_sender = current_app.config['MAIL_SENDER']

    headers = self.extra_headers
    if not headers:
        # No usable header mapping yet (None or empty): start a fresh dict.
        headers = self.extra_headers = {}
    headers['Sender'] = default_sender

    connection.send(self, default_sender)
# Report the original `flask_mail.Message.send` through `patch_logger` before
# replacing it with the patched implementation.
patch_logger.info(flask_mail.Message.send)
flask_mail.Message.send = _message_send
# Mail extension instance exported by this module.
mail = flask_mail.Mail()
import sqlalchemy as sa
from ..sqlalchemy import SQLAlchemy
# SQLAlchemy extension instance exported by this module; its `metadata` and
# `Model` are the targets of the event listeners registered in this file.
db = SQLAlchemy()
@sa.event.listens_for(db.metadata, 'before_create')
@sa.event.listens_for(db.metadata, 'before_drop')
def _filter_metadata_for_connection(target, connection, **kw):
    """
    Listener to control which indexes get created or dropped.

    Useful for skipping postgres-specific indexes on a sqlite for example.

    It looks for an `engines` entry in the `info` of an index
    (`Index(info=dict(engines=['postgresql']))`), an iterable of engine names.
    An index whose `engines` does not contain the name of the connection's
    engine is removed from its table's `indexes` collection (the table object
    itself is modified). Indexes without an `engines` entry are always kept.

    :param target: object the event was fired for. When it is a single
        `sa.Table`, only that table is processed.
    :param connection: connection the DDL is run on; `connection.engine.name`
        is the engine name looked up in `engines`.
    :param kw: remaining event keyword arguments. `tables` gives the tables to
        process when `target` is not a `sa.Table`.
    """
    engine = connection.engine.name
    default_engines = (engine,)

    if isinstance(target, sa.Table):
        # Wrap the lone table: the loop below expects an iterable of tables,
        # and iterating over the Table object itself would not provide that.
        tables = [target]
    else:
        # `tables` may be missing or explicitly None; both mean "nothing".
        tables = kw.get('tables') or ()

    for table in tables:
        # Iterate over a snapshot, since the collection is mutated in the loop.
        for idx in list(table.indexes):
            if engine not in idx.info.get('engines', default_engines):
                table.indexes.remove(idx)
# csrf
from .csrf import wtf_csrf as csrf, abilian_csrf
def get_extension(name):
    """Get the named extension from the current app, returning None if not found.

    :param name: key under which the extension is stored in
        `current_app.extensions`.
    :return: the stored extension object, or None when there is no such key.
    """
    return current_app.extensions.get(name)
def _install_get_display_value(cls):
_MARK = object()
def display_value(self, field_name, value=_MARK):
""" Return display value for fields having 'choices' mapping (stored value
-> human readable value). For other fields it will simply return field
value.
`display_value` should be used instead of directly getting field value.
If `value` is provided it is "tranlated" to a human-readable value. This is
useful for obtaining a human readable label from a raw value
"""
val = getattr(self, field_name) if value is _MARK else value
mapper = sa.orm.object_mapper(self)
try:
field = getattr(mapper.c, field_name)
except AttributeError:
pass
else:
if 'choices' in field.info:
get = lambda v: field.info['choices'].get(v, v)
if isinstance(val, list):
val = [get(v) for v in val]
else:
val = get(val)
return val
if not hasattr(cls, 'display_value'):
cls.display_value = display_value
# Have `_install_get_display_value` called for the 'class_instrument' event of
# `db.Model`, so that a `display_value` method gets installed on the class it
# receives.
sa.event.listen(db.Model, 'class_instrument', _install_get_display_value)
#
# Make Sqlite a bit more well-behaved.
#
@sa.event.listens_for(Engine, "connect")
def _set_sqlite_pragma(dbapi_connection, connection_record):
    """
    Turn on foreign key enforcement on SQLite connections.

    Registered for the "connect" event of `Engine`. Connections that are not
    `sqlite3.Connection` instances are left untouched.

    :param dbapi_connection: connection object passed by the event; the
        PRAGMA statement is executed on a cursor obtained from it.
    :param connection_record: passed by the event, unused here.
    """
    from sqlite3 import Connection as SQLite3Connection

    if isinstance(dbapi_connection, SQLite3Connection):  # pragma: no cover
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON;")
        finally:
            # Release the cursor even when the PRAGMA statement fails.
            cursor.close()