Source code for psynet.process

import datetime
import inspect
import threading
import time

import dallinger.db
from dallinger import db
from dallinger.db import redis_conn
from jsonpickle.util import importable_name
from rq import Queue
from rq.job import Job
from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    Float,
    ForeignKey,
    Integer,
    String,
    event,
)
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import deferred, relationship
from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_exponential

from .data import SQLBase, SQLMixin, register_table
from .db import with_transaction
from .field import PythonDict, PythonObject
from .utils import call_function, classproperty, get_logger

logger = get_logger()


@register_table
class AsyncProcess(SQLBase, SQLMixin):
    __tablename__ = "process"
    __extra_vars__ = SQLMixin.__extra_vars__.copy()

    label = Column(String)
    function = Column(PythonObject)
    arguments = deferred(Column(PythonDict))
    pending = Column(Boolean)
    finished = Column(Boolean, default=False)
    time_started = Column(DateTime)
    time_finished = Column(DateTime)
    time_taken = Column(Float)
    _unique_key = Column(PythonDict, unique=True)

    participant_id = Column(Integer, ForeignKey("participant.id"), index=True)
    participant = relationship(
        "psynet.participant.Participant", backref="async_processes"
    )

    trial_maker_id = Column(String, index=True)

    network_id = Column(Integer, ForeignKey("network.id"), index=True)
    network = relationship("TrialNetwork", back_populates="async_processes")

    node_id = Column(Integer, ForeignKey("node.id"), index=True)
    node = relationship("TrialNode", back_populates="async_processes")

    trial_id = Column(Integer, ForeignKey("info.id"), index=True)
    trial = relationship("psynet.trial.main.Trial", back_populates="async_processes")

    response_id = Column(Integer, ForeignKey("response.id"), index=True)
    response = relationship(
        "psynet.timeline.Response", back_populates="async_processes"
    )

    asset_id = Column(Integer, ForeignKey("asset.id"), index=True)
    asset = relationship("Asset", back_populates="async_processes")

    errors = relationship("ErrorRecord")

    # Class-level queue of launch specs, drained by launch_all() after each commit.
    launch_queue = []

    def add_to_launch_queue(self):
        self.launch_queue.append(self.get_launch_spec())

    def get_launch_spec(self) -> dict:
        db.session.flush([self])  # ensures that the process has been assigned an id
        return {
            "obj": self,
            "class": self.__class__,
            "id": self.id,
        }

    @classmethod
    def launch_all(cls):
        while cls.launch_queue:
            process = cls.launch_queue.pop(0)
            assert process["obj"].id is not None
            logger.info("Launching async process %s...", process["id"])
            process["class"].launch(process)

    def __init__(
        self,
        function,
        arguments=None,
        trial=None,
        response=None,
        participant=None,
        node=None,
        network=None,
        asset=None,
        label=None,
        unique=False,
    ):
        if label is None:
            label = function.__name__

        if arguments is None:
            arguments = {}

        db.session.flush()

        if inspect.ismethod(function):
            # Convert a bound instance method into the underlying class-level
            # function, passing the instance via the "self" argument instead.
            method_name = function.__name__
            method_caller = function.__self__
            function = getattr(method_caller.__class__, method_name)
            arguments["self"] = method_caller

        self.check_function(function)

        self.label = label
        self.function = function
        self.arguments = arguments

        self.asset = asset
        if asset:
            self.asset_id = asset.id

        self.participant = participant
        if participant:
            self.participant_id = participant.id

        self.network = network
        if network:
            self.network_id = network.id

        self.node = node
        if node:
            self.node_id = node.id

        self.trial = trial
        if trial:
            self.trial_id = trial.id

        self.response = response
        if response:
            self.response_id = response.id

        self.infer_participant()
        self.infer_trial_maker_id()

        self.pending = True

        if unique:
            if isinstance(unique, bool):
                self._unique_key = {
                    "label": label,
                    "function": function,
                    "arguments": arguments,
                }
            else:
                self._unique_key = unique

        db.session.add(self)
        self.add_to_launch_queue()

    def check_function(self, function):
        from .serialize import serialize, unserialize

        assert callable(function)

        if "<locals>" in importable_name(function):
            raise ValueError(
                "You cannot use a function defined within another function "
                "in an async process."
            )

        if unserialize(serialize(function)) is None:
            raise ValueError(
                "The provided function could not be serialized. "
                "Make sure that the function is defined at the module "
                "or class level, rather than being a lambda function or a temporary function defined within "
                "another function."
            )

        if inspect.ismethod(function):
            raise ValueError(
                "You cannot pass an instance method to an AsyncProcess. "
                "Try writing a class method or a static method instead."
            )

    def log_time_started(self):
        self.time_started = datetime.datetime.now()

    def log_time_finished(self):
        self.time_finished = datetime.datetime.now()

    def infer_participant(self):
        if self.participant is None:
            for obj in [self.asset, self.trial, self.node, self.network]:
                if obj and hasattr(obj, "participant") and obj.participant:
                    self.participant = obj.participant
                    break

        # For safety, set the foreign key explicitly as well.
        if self.participant:
            self.participant_id = self.participant.id

    def infer_trial_maker_id(self):
        for obj in [self.trial, self.node, self.network]:
            if obj and obj.trial_maker_id:
                self.trial_maker_id = obj.trial_maker_id
                return

    @property
    def failure_cascade(self):
        """
        These are the objects that will be failed if the process fails.
        Ultimately we might want to add more objects to this list,
        for example participants, assets, and networks,
        but currently we're not confident that PsyNet supports
        failing those objects in that kind of way.
        """
        candidates = [self.trial, self.node]
        return [lambda obj=obj: [obj] for obj in candidates if obj is not None]

    @classmethod
    def launch(cls, process: dict):
        raise NotImplementedError

    @classproperty
    def redis_queue(cls):
        return Queue("default", connection=redis_conn)

    @classmethod
    def call_function_with_logger(cls, process_id):
        cls.call_function(process_id)

    # The process gets launched when SQLAlchemy's after_commit event is triggered.
    # This tells us when the COMMIT has been issued to the database, but it does
    # not guarantee that the commit has finished execution. This is why we add
    # some retry logic to ensure that the process is available in the database
    # before continuing.
    @classmethod
    @retry(
        retry=retry_if_exception_type(NoResultFound),
        wait=wait_exponential(multiplier=0.1, min=0.01),
        stop=stop_after_delay(4),
    )
    def get_process(cls, process_id: int):
        return (
            AsyncProcess.query.filter_by(id=process_id)
            .with_for_update(of=AsyncProcess)
            .populate_existing()
            .one()
        )
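
    # A minimal sketch of the launch lifecycle (names such as `MyProcess` and
    # `my_function` are hypothetical):
    #
    #     MyProcess(my_function, arguments={"x": 1})
    #         # __init__ adds the row to db.session and a launch spec to launch_queue
    #     db.session.commit()
    #         # fires the after_commit listener below, which calls launch_all(),
    #         # which calls MyProcess.launch(spec)
    #     MyProcess.launch(spec)
    #         # schedules call_function_with_logger(process_id) on a thread
    #         # (LocalAsyncProcess) or on the Redis queue (WorkerAsyncProcess)
    #     MyProcess.call_function(process_id)
    #         # loads the row via get_process() (retrying until the commit is
    #         # visible) and invokes the stored function with its arguments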

    @classmethod
    @with_transaction
    def call_function(cls, process_id):
        """
        Calls the defining function of a given process.
        """
        # cls.log(f"Calling function for process_id: {process_id}")
        print("\n")
        logger.info(f"Calling function for process_id {process_id}...")

        from psynet.experiment import get_experiment

        # Fetch the experiment before entering the try block so that the
        # except clause can always reference it.
        experiment = get_experiment()
        process = None
        try:
            process = cls.get_process(process_id)
            function = process.function
            arguments = cls.preprocess_args(process.arguments)

            timer = time.monotonic()
            process.time_started = datetime.datetime.now()
            db.session.commit()

            call_function(function, **arguments)

            process.time_finished = datetime.datetime.now()
            process.time_taken = time.monotonic() - timer
            process.pending = False
            process.finished = True

            from psynet.trial.main import Trial

            if "self" in arguments and isinstance(arguments["self"], Trial):
                arguments["self"].check_if_can_mark_as_finalized()
        except Exception as err:
            if not isinstance(err, experiment.HandledError):
                experiment.handle_error(err, process=process)
            if process:
                process.pending = False
                process.fail(f"Exception in asynchronous process: {repr(err)}")
        finally:
            db.session.commit()

    @classmethod
    def preprocess_args(cls, arguments):
        """
        Preprocesses the arguments that are passed to the process's function.
        """
        return {key: cls.preprocess_arg(value) for key, value in arguments.items()}

    @classmethod
    def preprocess_arg(cls, arg):
        if isinstance(arg, dallinger.db.Base):
            # The argument is an SQLAlchemy object: reattach it to the current
            # database session and reload its state from the database.
            arg = db.session.merge(arg)
            db.session.refresh(arg)
        return arg

    @classmethod
    def log(cls, msg):
        raise NotImplementedError

    @classmethod
    def log_to_stdout(cls, msg):
        print(msg)

    @classmethod
    def log_to_redis(cls, msg):
        cls.redis_queue.enqueue_call(
            func=logger.info, args=(), kwargs=dict(msg=msg), timeout=1e10, at_front=True
        )
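

# A minimal sketch of why preprocess_arg merges and refreshes SQLAlchemy
# objects (the Trial query is hypothetical): the process's arguments are
# deserialized in a different thread or worker, so any mapped objects arrive
# detached from the current database session.
#
#     trial = Trial.query.get(1)       # bound to the original session
#     trial = db.session.merge(trial)  # reattach to this session
#     db.session.refresh(trial)        # reload its current state from the database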


@event.listens_for(db.session, "after_commit")
def receive_after_commit(session):
    AsyncProcess.launch_all()
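

# A minimal sketch, assuming two hypothetical module-level functions f and g:
# every process constructed before a commit is launched by that commit's
# after_commit event, because launch_all() drains the shared launch_queue.
#
#     LocalAsyncProcess(f)
#     LocalAsyncProcess(g)
#     db.session.commit()  # launch_all() launches both f and g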


class LocalAsyncProcess(AsyncProcess):
    @classmethod
    def launch(cls, process: dict):
        thr = threading.Thread(
            target=cls.thread_function, kwargs={"process_id": process["id"]}
        )
        thr.start()

    @classmethod
    def thread_function(cls, process_id):
        try:
            cls.call_function_with_logger(process_id)
        finally:
            db.session.commit()
            db.session.close()

    @classmethod
    def call_function_with_logger(cls, process_id):
        cls.call_function(process_id)
        # log = io.StringIO()
        # try:
        #     with contextlib.redirect_stdout(log):
        #         # yield
        #         cls.call_function(process_id)
        # finally:
        #     cls.log_to_redis(log.getvalue())

        # with cls.log_output():
        #     cls.call_function(process_id)

    # @classmethod
    # @contextlib.contextmanager
    # def log_output(cls):
    #     log = io.StringIO()
    #     try:
    #         with contextlib.redirect_stdout(log):
    #             yield
    #     finally:
    #         cls.log_to_redis(log.getvalue())

    # @classmethod
    # def log(cls, msg):
    #     cls.log_to_redis(msg)
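

# A minimal sketch of LocalAsyncProcess usage, assuming a hypothetical
# module-level function (it must be defined at module or class level so that
# check_function can serialize it):
#
#     def normalize_audio(asset):
#         ...  # runs in a background thread within the current process
#
#     LocalAsyncProcess(normalize_audio, arguments={"asset": asset}, asset=asset)
#     db.session.commit()  # the thread starts once the transaction commits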


class WorkerAsyncProcess(AsyncProcess):
    redis_job_id = Column(String)
    timeout = Column(Float)  # note -- currently only applies to non-local processes
    timeout_scheduled_for = Column(DateTime)
    cancelled = Column(Boolean, default=False)

    def get_launch_spec(self) -> dict:
        spec = super().get_launch_spec()
        spec["timeout"] = self.timeout
        return spec

    def __init__(
        self,
        function,
        arguments=None,
        trial=None,
        participant=None,
        node=None,
        network=None,
        asset=None,
        label=None,
        unique=False,
        timeout=None,  # <-- new argument for this class
    ):
        self.timeout = timeout
        if timeout:
            self.timeout_scheduled_for = datetime.datetime.now() + datetime.timedelta(
                seconds=timeout
            )
        super().__init__(
            function,
            arguments,
            trial=trial,
            participant=participant,
            node=node,
            network=network,
            asset=asset,
            label=label,
            unique=unique,
        )

    @classmethod
    def launch(cls, process: dict):
        # Previously we took the id of the enqueue_call and saved that in
        # Process.redis_job_id, but this is not possible now that the Process
        # object is not accessible.
        cls.redis_queue.enqueue_call(
            func=cls.call_function_with_logger,
            args=(),
            kwargs=dict(process_id=process["id"]),
            timeout=process["timeout"],
        )

    @classmethod
    def check_timeouts(cls):
        processes = cls.query.filter(
            cls.pending,
            ~cls.failed,
            cls.timeout != None,  # noqa -- this is special SQLAlchemy syntax
            cls.timeout_scheduled_for < datetime.datetime.now(),
        ).all()
        for p in processes:
            p.fail("Asynchronous process timed out")
        db.session.commit()

    @property
    def redis_job(self):
        return Job.fetch(self.redis_job_id, connection=redis_conn)

    def cancel(self):
        self.cancelled = True
        self.pending = False
        self.fail("Cancelled asynchronous process")
        self.redis_job.cancel()
        db.session.commit()
    # @classmethod
    # def log(cls, msg):
    #     cls.log_to_stdout(msg)
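

# A minimal sketch of WorkerAsyncProcess usage, assuming a hypothetical
# module-level function: the job runs on an RQ worker rather than in a local
# thread, and `timeout` (in seconds) is enforced both by RQ and by
# check_timeouts(), which is presumably invoked periodically elsewhere.
#
#     def fit_model(network):
#         ...  # heavy computation, executed on the "default" Redis queue
#
#     WorkerAsyncProcess(
#         fit_model, arguments={"network": network}, network=network, timeout=300
#     )
#     db.session.commit()  # enqueues the job via redis_queue.enqueue_call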