# Source code for psynet.serialize

import pickle
import re
import warnings
from functools import cached_property

import dominate.tags
import jsonpickle
from jsonpickle import Pickler
from jsonpickle.unpickler import Unpickler, loadclass
from markupsafe import Markup

from .data import SQLBase
from .utils import get_logger

logger = get_logger()

# old_loadclass = jsonpickle.unpickler.loadclass
#
#
# def check_mappers():
#     # If we don't manage the imports correctly, we can end up with a nasty bug where
#     # SQLAlchemy ends up registering two mappers for every class in experiment.py.
#     # The following test catches such cases.
#     from dallinger.db import Base
#
#     animal_trial_mappers = [
#         m for m in Base.registry.mappers if m.class_.__name__ == "AnimalTrial"
#     ]
#     assert len(animal_trial_mappers) == 1
#
#
# def loadclass(module_and_name, classes=None):
#     check_mappers()
#     old_loadclass(module_and_name, classes)
#     check_mappers()


# NOTE(review): `loadclass` here is the same function imported from
# jsonpickle.unpickler above, so this assignment is currently a no-op.
# It appears to be a leftover hook point for the commented-out mapper-checking
# wrapper above — confirm before removing.
jsonpickle.unpickler.loadclass = loadclass


def is_lambda_function(x):
    """Return True if ``x`` is a function created with a ``lambda`` expression."""
    return callable(x) and getattr(x, "__name__", None) == "<lambda>"


class PsyNetPickler(Pickler):
    """
    Custom pickler that refuses to serialize lambda functions, pointing the
    user at the location where the offending lambda was defined.
    """

    def flatten(self, obj, reset=True):
        if not is_lambda_function(obj):
            return super().flatten(obj, reset=reset)
        # Best-effort lookup of where the lambda was defined, so the error
        # message can point the user at the right place.
        try:
            code = obj.__code__
            source_file, source_line = code.co_filename, code.co_firstlineno
        except Exception as e:
            source_file = source_line = "UNKNOWN"
            logger.error(
                msg="Failed to find source code for lambda function.", exc_info=e
            )
        raise TypeError(
            "Cannot pickle lambda functions. "
            "Can you replace this function with a named function defined by `def`?\n"
            f"The problematic function was defined in {source_file} "
            f"on line {source_line}."
        )


class PsyNetUnpickler(Unpickler):
    """
    Custom unpickler that resolves references to objects defined in the local
    ``dallinger_experiment`` package, and looks up SQLAlchemy-mapped objects
    in the database instead of reconstructing them from JSON state.
    """

    def _restore_object(self, obj):
        # Resolve the target class: either from the local experiment package,
        # or via jsonpickle's standard class loader.
        cls_id = obj["py/object"]
        if cls_id.startswith("dallinger_experiment"):
            cls = self.get_experiment_object(cls_id)
        else:
            cls = loadclass(cls_id)
        if hasattr(cls, "_sa_registry"):
            # SQLAlchemy-mapped class: fetch the row from the database.
            return self.load_sql_object(cls, obj)
        self.register_classes(cls)
        return super()._restore_object(obj)

    def _restore_function(self, obj):
        if isinstance(obj, dict) and "py/function" in obj:
            spec = obj["py/function"]
            if spec.startswith("dallinger_experiment"):
                return self.get_experiment_object(spec)
        return super()._restore_function(obj)

    def get_experiment_object(self, spec):
        # Walk the dotted attribute path after the leading
        # "dallinger_experiment" package name.
        package_name, *attr_path = spec.split(".")
        assert package_name == "dallinger_experiment"
        target = self.experiment["package"]
        for attr in attr_path:
            target = getattr(target, attr)
        return target

    def load_sql_object(self, cls, obj):
        # Look the object up by its serialized primary keys; warn (rather
        # than fail) if it is no longer present in the database.
        identifiers = obj["identifiers"]
        instance = cls.query.filter_by(**identifiers).one_or_none()
        if instance is None:
            warnings.warn(
                f"The unserializer failed to find the following object in the database: {obj}. "
                "Returning `None` instead."
            )
        return instance

    @cached_property
    def experiment(self):
        # Imported lazily (and cached) to avoid circular imports at module load.
        from .experiment import import_local_experiment

        return import_local_experiment()
def serialize(x, **kwargs):
    """Encode ``x`` as a JSON string using the PsyNet-specific pickler."""
    return jsonpickle.encode(x, **kwargs, context=PsyNetPickler(), warn=True)


def to_dict(x):
    """Flatten ``x`` to a JSON-compatible structure without string-encoding it."""
    return PsyNetPickler().flatten(x)


def unserialize(x):
    """Decode a JSON string previously produced by :func:`serialize`."""
    # If we don't provide the custom classes directly, jsonpickle tries to find them itself,
    # and ends up messing up the SQLAlchemy mapper registration system,
    # producing duplicate mappers for each custom class.
    return jsonpickle.decode(x, context=PsyNetUnpickler())


# These classes cannot be reliably pickled by the `jsonpickle` library.
# Instead we fall back to Python's built-in pickle library.
no_json_classes = [Markup]
class NoJSONHandler(jsonpickle.handlers.BaseHandler):
    """
    Fallback handler for classes jsonpickle cannot serialize reliably:
    the object is round-tripped through Python's built-in ``pickle``
    (protocol 0, which is ASCII-based) and stored as a latin-1 string.
    """

    def flatten(self, obj, state):
        # latin-1 maps every byte value 0-255, so the decode cannot fail.
        state["bytes"] = pickle.dumps(obj, 0).decode("latin-1")
        return state

    def restore(self, state):
        return pickle.loads(state["bytes"].encode("latin-1"))
# Route every class in `no_json_classes` through the pickle-based fallback.
for _cls in no_json_classes:
    jsonpickle.register(_cls, NoJSONHandler, base=True)
class SQLHandler(jsonpickle.handlers.BaseHandler):
    """
    Serializes SQLAlchemy-mapped objects by their primary keys only.
    Restoration is performed by ``PsyNetUnpickler.load_sql_object``, which
    looks the row up in the database, so this handler's ``restore`` should
    never be invoked directly.
    """

    def get_primary_keys(self, obj):
        # Map each primary-key column name to its current value on the object.
        primary_key_cols = [c.name for c in obj.__class__.__table__.primary_key.columns]
        return {key: getattr(obj, key) for key in primary_key_cols}

    def flatten(self, obj, state):
        primary_keys = self.get_primary_keys(obj)
        if any(key is None for key in primary_keys.values()):
            raise ValueError(
                f"Cannot serialize {obj}. It has a `None` value for one of its primary keys: {primary_keys}. "
                "It might be possible to solve this problem by introducing a `db.session.flush()` call before pickling."
            )
        state["identifiers"] = primary_keys
        return state

    def restore(self, state):
        # Restoration goes through PsyNetUnpickler.load_sql_object instead.
        # (An unreachable database lookup that previously followed this raise
        # has been removed as dead code.)
        raise RuntimeError("This should not be called directly")
# All SQLAlchemy models (subclasses of SQLBase) serialize via primary keys.
jsonpickle.register(SQLBase, SQLHandler, base=True)
class DominateHandler(jsonpickle.handlers.BaseHandler):
    """
    Serializes dominate HTML tags by rendering them to their HTML string.
    Serialization is one-way: no ``restore`` is provided here.
    """

    def flatten(self, obj, state):
        # The rendered markup is the complete serialized representation.
        return str(obj)
# Any dominate tag is rendered to its HTML string when serialized.
jsonpickle.register(dominate.dom_tag.dom_tag, DominateHandler, base=True)