Source code for bioverse.frameworks.pyg

import awkward as ak
import numpy as np
import torch
from torch_geometric.data import Data

from ..framework import Framework


[docs] class PygData(Data): def uncollate(self, y): if "sizes" in self: y = ak.unflatten(y, self.sizes, axis=0) y = ak.Array({"target": y}) return y
[docs] class PygFramework(Framework): @classmethod def collate(cls, X, y=None, attr=[]): if X.resolution == "atom": num_vertices = X.toc["atom"].sum(axis=-1).sum(axis=-1).ravel() else: num_vertices = X.toc["residue"].sum(axis=-1).ravel() num_molecules = len(num_vertices) vertex2batch = torch.arange(num_molecules).repeat_interleave( ak.to_torch(num_vertices) ) if "molecule_graph" in X: offsets = np.insert(np.cumsum(num_vertices), 0, 0)[:-1] X.molecule_graph = X.molecule_graph + offsets return PygData( features=( ak.to_torch(X.vertex_features).float() if "vertex_features" in X else None ), token=(ak.to_torch(X.vertex_token).int() if "vertex_token" in X else None), pos=(ak.to_torch(X.vertex_pos).float() if "vertex_pos" in X else None), edge_index=( ak.to_torch(ak.concatenate(X.molecule_graph, axis=1)).long() if "molecule_graph" in X else None ), # num_nodes=ak.to_torch(len(X[i]["vertices"].features)), mask=(ak.to_torch(X.vertex_mask).bool() if "vertex_mask" in X else None), y=( ak.to_torch( ak.flatten(y["target"], axis=1) if "sizes" in y.fields else y["target"] ).float() if not y is None else None ), vertex2batch=vertex2batch, num_vertices=num_vertices, num_molecules=num_molecules, sizes=(y["sizes"] if not y is None and "sizes" in y.fields else None), **{attr: ak.to_torch(X.__getattr__(attr)) for attr in attr}, )