Source code for bioverse.benchmark

import inspect
import os
from abc import ABC
from functools import cached_property, partial
from pathlib import Path
from typing import Iterable, Tuple

import awkward as ak
from typing_extensions import Self

from .data import Batch
from .dataset import Dataset
from .framework import Framework
from .metric import Metric, Result
from .sampler import Sampler
from .task import Task
from .transform import Transform
from .utilities import config, load, parallelize
from .virtual import VirtualBatch


class Benchmark(ABC):
    """Bundle a dataset, a sampler, a task and a metric into one benchmark.

    Subclasses supply the four components as class attributes. Each one may
    be a class, an already constructed instance, or a ``(class, kwargs)``
    tuple; ``__init__`` turns them into instances stored on ``self``.
    """

    dataset: Dataset
    sampler: Sampler
    task: Task
    metric: Metric

    def __init__(
        self,
        root: Path | str = config.benchmarks_path,
        version: int = 0,
        partition: str = "main",
        n_jobs: int | None = None,
    ) -> None:
        """
        Parameters
        ----------
        root : Path or str, optional
            Root directory for benchmark data storage, defaults to
            config.benchmarks_path. The benchmark's own directory is
            ``<root>/<ClassName>v<version>/<partition>`` and is created if it
            does not exist.
        version : int, optional
            Version number of the benchmark, defaults to 0
        partition : str, optional
            Partition name for the benchmark data, defaults to "main"
        n_jobs : int or None, optional
            Number of parallel jobs to run, stored as ``self.n_jobs``.
            NOTE(review): previously documented as "None uses all available
            cores", but no method shown in this class reads the value
            (``loader`` passes a fixed ``max_workers=10``) -- confirm where
            it is consumed.

        Raises
        ------
        AttributeError
            If one of the required component class attributes is missing.

        Notes
        -----
        The benchmark class requires the following class attributes to be
        defined:

        - dataset: Dataset class, instance, or (class, kwargs) tuple
        - sampler: Sampler class, instance, or (class, kwargs) tuple
        - task: Task class, instance, or (class, kwargs) tuple
        - metric: Metric class, instance, or (class, kwargs) tuple
        """
        self.root = Path(root) / f"{self.name}v{version}" / partition
        self.root.mkdir(parents=True, exist_ok=True)
        self.partition = partition
        self.n_jobs = n_jobs

        # initialize components from config, class, or instance
        for attr in ("dataset", "sampler", "task", "metric"):
            try:
                spec = getattr(self, attr)
            except AttributeError as err:
                # A bare annotation on the class does not create the
                # attribute, so name the missing component explicitly.
                raise AttributeError(
                    f"{self.name} must define the class attribute '{attr}'"
                ) from err

            if isinstance(spec, tuple):
                # (class, kwargs) form: build the component with its kwargs
                cls, kwargs = spec
                setattr(self, attr, cls(**kwargs))
            elif inspect.isclass(spec):
                # bare class: build it with no arguments
                setattr(self, attr, spec())
            else:
                # already an instance: store it on this benchmark as-is
                setattr(self, attr, spec)

    @property
    def name(self) -> str:
        """Name of the benchmark, i.e. the name of the concrete class."""
        return self.__class__.__name__

    @property
    def leaderboard(self) -> Result:
        """Load the leaderboard stored for this benchmark.

        Reads ``leaderboards/<name>.yml`` located next to this module on
        every access and wraps the loaded content in a ``Result``.
        """
        # temporary fix for leaderboard database
        file_path = os.path.realpath(os.path.dirname(__file__))
        # presumably the second argument of `load` is the fallback used when
        # the file is absent -- verify against `utilities.load`
        return Result(load(f"{file_path}/leaderboards/{self.name}.yml", []))

    def apply(self, *transforms: Transform) -> Self:
        """Apply transforms to the dataset and (re)create the split loaders.

        Parameters
        ----------
        *transforms : Transform
            Transforms forwarded to ``self.dataset.apply``.

        Returns
        -------
        Self
            This benchmark, so that calls can be chained.
        """
        self.dataset.apply(*transforms)
        # initialize loaders (transforms can change splits, so we do it here)
        for split in self.dataset.split.attrs["names"]:
            # exposes e.g. ``<split>_loader(**kwargs)`` on the instance
            setattr(self, f"{split}_loader", partial(self.wrap_loader, split=split))
        return self

    def wrap_loader(self, *args, **kwargs):
        """Wrap ``loader`` in a re-iterable object.

        The arguments are those of ``loader``. A missing or ``None``
        ``random_seed`` is replaced by ``config.seed``.

        Returns
        -------
        Loader
            An object whose ``__iter__`` calls ``loader`` with the stored
            arguments. Every call to ``__iter__`` first increases the stored
            ``random_seed`` by one, so each pass over the loader (including
            the first one) uses a seed different from the previous pass.
        """
        if kwargs.get("random_seed") is None:
            kwargs["random_seed"] = config.seed

        class Loader:
            def __init__(self, loader):
                self.loader = loader
                self.args = args
                self.kwargs = kwargs

            def __iter__(self):
                # bump the seed on every pass before building the iterator
                self.kwargs["random_seed"] += 1
                return self.loader(*self.args, **self.kwargs)

        return Loader(self.loader)

    def loader(
        self,
        split: str,
        batch_size: int = 1,
        batch_on: str = "molecules",
        shuffle: bool = False,
        drop_last: bool = False,
        random_seed: int = config.seed,
        world_size: int = 1,
        rank: int = 0,
        progress: bool = True,
        framework: Framework | None = None,
        scratch: bool = False,
    ) -> Iterable[Tuple[Tuple[ak.Array, ...], Batch | None]]:
        """Produce the batches of one split.

        Parameters
        ----------
        split : str
            Name of the split to load.
        batch_size, batch_on, shuffle, drop_last, random_seed, world_size, rank
            Forwarded unchanged to ``self.sampler.sample`` together with the
            dataset and the split.
        progress : bool, optional
            Whether to request progress reporting from ``parallelize``; it is
            only requested on rank 0.
        framework : Framework or None, optional
            If given (and truthy), its ``collate`` is called on the task
            output of every batch; otherwise the collated entry is ``None``.
        scratch : bool, optional
            If True, ``self.dataset.move_to_scratch()`` is called first.

        Returns
        -------
        Iterable
            The result of ``parallelize`` over the sampled index; each worker
            call returns a pair ``(Xy, data)`` of the task output and the
            collated batch (or ``None``).
        """
        if scratch:
            self.dataset.move_to_scratch()

        # sample the index
        index = self.sampler.sample(
            dataset=self.dataset,
            split=split,
            batch_size=batch_size,
            batch_on=batch_on,
            shuffle=shuffle,
            drop_last=drop_last,
            random_seed=random_seed,
            world_size=world_size,
            rank=rank,
        )

        # The worker closes over these locals rather than over ``self``.
        task, vbatch, assets = self.task, self.dataset.virtual(), self.dataset.assets

        def worker(batch_index):
            Xy = task(vbatch, assets, batch_index)
            data = framework.collate(*Xy) if framework else None
            return Xy, data

        # NOTE(review): the worker count is fixed at 10; ``self.n_jobs`` is
        # not consulted here.
        return parallelize(
            worker,
            index,
            description="Loader",
            progress=progress and rank == 0,
            total=len(index),
            max_workers=10,
        )

    def update(self, y_true: ak.Array, y_pred: ak.Array) -> None:
        """Feed one batch of targets and predictions to the metric.

        Both arrays are first passed through
        ``self.dataset.transform.inverse_transform`` and then handed to
        ``self.metric.update``.
        """
        y_true = self.dataset.transform.inverse_transform(y_true)
        y_pred = self.dataset.transform.inverse_transform(y_pred)
        self.metric.update(y_true, y_pred)

    def result(self, *args, **kwargs) -> Result:
        """Return the leaderboard combined with the current metric result.

        All arguments are forwarded to ``self.metric.result``; its return
        value is added to ``self.leaderboard`` with the ``+`` operator.
        """
        return self.leaderboard + self.metric.result(*args, **kwargs)