import pickle
import re
import warnings
from functools import cached_property
import dominate.tags
import jsonpickle
from jsonpickle import Pickler
from jsonpickle.unpickler import Unpickler, loadclass
from markupsafe import Markup
from .data import SQLBase
from .utils import get_logger
logger = get_logger()

# Debugging aid (currently disabled): wrap jsonpickle's `loadclass` so that every
# class lookup is bracketed by a check for duplicate SQLAlchemy mappers.
# NOTE(review): if this is ever re-enabled, the wrapper must
# `return old_loadclass(module_and_name, classes)`; as written it discards the
# loaded class and returns None.
# old_loadclass = jsonpickle.unpickler.loadclass
#
#
# def check_mappers():
#     # If we don't manage the imports correctly, we can end up with a nasty bug where
#     # SQLAlchemy ends up registering two mappers for every class in experiment.py.
#     # The following test catches such cases.
#     from dallinger.db import Base
#
#     animal_trial_mappers = [
#         m for m in Base.registry.mappers if m.class_.__name__ == "AnimalTrial"
#     ]
#     assert len(animal_trial_mappers) == 1
#
#
# def loadclass(module_and_name, classes=None):
#     check_mappers()
#     old_loadclass(module_and_name, classes)
#     check_mappers()

# Re-bind jsonpickle's module-level `loadclass` to the `loadclass` imported from
# `jsonpickle.unpickler` at the top of this file. With the wrapper above commented
# out, this writes jsonpickle's own function back to itself, which is effectively
# a no-op unless a module imported earlier replaced it.
jsonpickle.unpickler.loadclass = loadclass
def is_lambda_function(x):
    """
    Report whether ``x`` is a lambda function.

    A value counts as a lambda when it is callable and its ``__name__`` is the
    literal ``"<lambda>"``. Values without a ``__name__`` are never lambdas.
    """
    if not callable(x):
        return False
    return getattr(x, "__name__", None) == "<lambda>"
class PsyNetPickler(Pickler):
    """
    jsonpickle ``Pickler`` that refuses to serialize lambda functions.

    When asked to flatten a lambda it raises a ``TypeError`` that names the file
    and line where the lambda was defined and suggests a named ``def`` instead.
    Every other object is delegated unchanged to jsonpickle.
    """

    def flatten(self, obj, reset=True):
        """
        Flatten ``obj`` into jsonpickle's JSON-friendly representation.

        :param obj: the object to flatten.
        :param reset: forwarded to ``Pickler.flatten``.
        :raises TypeError: if ``obj`` is a lambda function.
        """
        # NOTE(review): only the object passed in here is checked; lambdas nested
        # inside containers are flattened by jsonpickle's internal methods and
        # may not pass through this override — confirm if that matters.
        if not is_lambda_function(obj):
            return super().flatten(obj, reset=reset)

        try:
            code = obj.__code__
            source_file = code.co_filename
            source_line = code.co_firstlineno
        except Exception as e:
            source_file = source_line = "UNKNOWN"
            logger.error(
                msg="Failed to find source code for lambda function.", exc_info=e
            )

        raise TypeError(
            "Cannot pickle lambda functions. "
            "Can you replace this function with a named function defined by `def`?\n"
            f"The problematic function was defined in {source_file} "
            f"on line {source_line}."
        )
class PsyNetUnpickler(Unpickler):
    """
    jsonpickle ``Unpickler`` with PsyNet-specific restoration rules.

    * Classes and functions whose import path starts with
      ``dallinger_experiment`` are resolved against the locally imported
      experiment package (see :meth:`get_experiment_object`) instead of being
      imported by jsonpickle.
    * Classes carrying an SQLAlchemy ``_sa_registry`` attribute are not rebuilt
      from their serialized state; the matching row is fetched from the
      database using the stored ``identifiers`` (see :meth:`load_sql_object`).
    """

    def _restore_object(self, obj):
        """
        Restore a ``py/object`` entry, intercepting experiment and SQL classes.

        :param obj: the encoded dict; must contain a ``"py/object"`` key.
        :returns: the restored object, or the result of :meth:`load_sql_object`
            for SQLAlchemy-mapped classes.
        """
        cls_id = obj["py/object"]
        if cls_id.startswith("dallinger_experiment"):
            cls = self.get_experiment_object(cls_id)
        else:
            cls = loadclass(cls_id)

        if hasattr(cls, "_sa_registry"):
            return self.load_sql_object(cls, obj)

        # `loadclass` returns None when the class cannot be imported.
        # `register_classes` expects an actual class, so registering None would
        # crash here. Skip it and let jsonpickle's own fallback for unknown
        # classes take over.
        if cls is not None:
            self.register_classes(cls)
        return super()._restore_object(obj)

    def _restore_function(self, obj):
        """Restore a ``py/function`` entry, resolving experiment functions locally."""
        if isinstance(obj, dict) and "py/function" in obj:
            if obj["py/function"].startswith("dallinger_experiment"):
                return self.get_experiment_object(obj["py/function"])
        return super()._restore_function(obj)

    def get_experiment_object(self, spec):
        """
        Resolve a dotted ``dallinger_experiment.<attr>.<attr>...`` path.

        Starting from the local experiment package, each dotted component after
        the first is looked up with ``getattr``.

        :param spec: dotted path whose first component is ``dallinger_experiment``.
        :raises AssertionError: if the first component is anything else.
        :raises AttributeError: if a component cannot be found.
        """
        split = spec.split(".")
        package_spec = split[0]
        remainder_spec = split[1:]
        assert package_spec == "dallinger_experiment"
        current = self.experiment["package"]
        for x in remainder_spec:
            current = getattr(current, x)
        return current

    def load_sql_object(self, cls, obj):
        """
        Fetch the database row identified by ``obj["identifiers"]``.

        :param cls: SQLAlchemy-mapped class exposing ``query``.
        :param obj: encoded dict holding an ``"identifiers"`` mapping of
            attribute name to value (as written by ``SQLHandler.flatten``).
        :returns: the matching object, or ``None`` (with a warning) if no row matches.
        """
        identifiers = obj["identifiers"]
        res = cls.query.filter_by(**identifiers).one_or_none()
        if res is None:
            warnings.warn(
                f"The unserializer failed to find the following object in the database: {obj}. "
                "Returning `None` instead."
            )
        return res

    @cached_property
    def experiment(self):
        """
        The locally imported experiment, loaded lazily and cached per unpickler.

        Presumably a mapping with a ``"package"`` entry (read by
        :meth:`get_experiment_object`) — see ``import_local_experiment``.
        """
        from .experiment import import_local_experiment

        return import_local_experiment()
def serialize(x, **kwargs):
    """
    Encode ``x`` as a JSON string using :class:`PsyNetPickler`.

    Extra keyword arguments are forwarded to :func:`jsonpickle.encode`.
    Because an explicit ``context`` pickler is supplied, ``jsonpickle.encode``
    does not apply its pickler-configuration arguments (``warn``,
    ``unpicklable``, ``make_refs``, ...). Those settings only take effect when
    given to the pickler itself.

    :raises TypeError: if ``x`` is a lambda function (see :class:`PsyNetPickler`).
    """
    # jsonpickle.encode only honours its own `warn` argument when it builds a
    # Pickler itself (i.e. when no `context` is passed). To actually get
    # warnings about unserializable objects, enable them on our pickler.
    pickler = PsyNetPickler(warn=True)
    return jsonpickle.encode(x, **kwargs, context=pickler, warn=True)
def to_dict(x):
    """
    Return jsonpickle's flattened (not string-encoded) representation of ``x``.

    Flattening uses a fresh :class:`PsyNetPickler`, so lambda functions are rejected.
    """
    return PsyNetPickler().flatten(x)
def unserialize(x):
    """
    Decode a JSON string (as produced by :func:`serialize`) back into objects.

    Decoding uses a fresh :class:`PsyNetUnpickler`, which resolves
    ``dallinger_experiment`` classes and functions against the local experiment
    and re-queries SQLAlchemy-mapped objects from the database.
    """
    # Historical note: letting jsonpickle locate experiment classes on its own
    # used to make SQLAlchemy register duplicate mappers for each custom class.
    # PsyNetUnpickler resolves those classes itself, presumably sidestepping that.
    return jsonpickle.decode(x, context=PsyNetUnpickler())
# These classes cannot be reliably pickled by the `jsonpickle` library.
# Instead we fall back to Python's built-in pickle library.
no_json_classes = [Markup]


class NoJSONHandler(jsonpickle.handlers.BaseHandler):
    """
    jsonpickle handler that embeds objects as a Python ``pickle`` payload.

    Used for the classes in ``no_json_classes``. The object is pickled with
    protocol 0, and the resulting bytes are stored as a latin-1 string under
    ``state["bytes"]``. Latin-1 maps every byte value 0-255 to exactly one code
    point, so the bytes survive the round trip unchanged.
    """

    def flatten(self, obj, state):
        """Store the pickled form of ``obj`` in ``state["bytes"]`` and return ``state``."""
        state["bytes"] = pickle.dumps(obj, protocol=0).decode("latin-1")
        return state

    def restore(self, state):
        """Rebuild the object from the pickle payload in ``state["bytes"]``."""
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload. Only decode data this application serialized itself; never
        # decode JSON received from untrusted sources.
        return pickle.loads(state["bytes"].encode("latin-1"))


# Route each listed class (and, via base=True, its subclasses) through NoJSONHandler.
for _cls in no_json_classes:
    jsonpickle.register(_cls, NoJSONHandler, base=True)
class SQLHandler(jsonpickle.handlers.BaseHandler):
    """
    jsonpickle handler for SQLAlchemy objects deriving from ``SQLBase``.

    Only the object's primary-key values are serialized, under
    ``state["identifiers"]``. Decoding is handled by ``PsyNetUnpickler``: it
    detects SQLAlchemy-mapped classes (via ``_sa_registry``) and re-queries the
    database with these identifiers, so :meth:`restore` is not meant to be reached.
    """

    def get_primary_keys(self, obj):
        """Return ``{column_name: value}`` for every primary-key column of ``obj``'s table."""
        # NOTE(review): table column names are used as attribute names; a primary
        # key column mapped under a different attribute name would break
        # `getattr` here and `filter_by` on decode — confirm this never happens.
        primary_key_cols = [c.name for c in obj.__class__.__table__.primary_key.columns]
        return {key: getattr(obj, key) for key in primary_key_cols}

    def flatten(self, obj, state):
        """
        Record ``obj``'s primary keys in ``state["identifiers"]``.

        :raises ValueError: if any primary-key value is ``None`` (e.g. the
            object has not been flushed to the database yet).
        """
        primary_keys = self.get_primary_keys(obj)
        if any(value is None for value in primary_keys.values()):
            raise ValueError(
                f"Cannot serialize {obj}. It has a `None` value for one of its primary keys: {primary_keys}. "
                "It might be possible to solve this problem by introducing a `db.session.flush()` call before pickling."
            )
        state["identifiers"] = primary_keys
        return state

    def restore(self, state):
        """
        Unsupported: SQL objects are restored by ``PsyNetUnpickler``.

        :raises RuntimeError: always.
        """
        raise RuntimeError("This should not be called directly")


jsonpickle.register(SQLBase, SQLHandler, base=True)
class DominateHandler(jsonpickle.handlers.BaseHandler):
    """
    jsonpickle handler for ``dominate`` HTML elements.

    Elements are flattened to their string form (``str(obj)``). The conversion
    is one-way: no ``restore`` is defined, and the encoded value is a plain
    string rather than a tagged object.
    """

    def flatten(self, obj, state):
        """Return ``str(obj)``; the ``state`` dict is not used."""
        return str(obj)


jsonpickle.register(dominate.dom_tag.dom_tag, DominateHandler, base=True)