Coverage for src\model2sas\readfile.py: 100%
96 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:14 +0800
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:14 +0800
1from abc import ABC, abstractmethod
2from typing import Optional, Literal
3from pathlib import Path
4import sys
6import torch
7from torch import Tensor
8from stl import mesh
9import Bio.PDB
10import periodictable as pt
11import numpy as np
13from .global_vars import PRECISION
14from .model import GridModel, BoundingBox
15from . import calcfunc
16from .utils import log
@log
def meshgrid(bounding_box: BoundingBox, n_long: int = 50, spacing: Optional[float] = None, device: Optional[str|torch.device] = None, against_edges: bool = False) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Generate a regular 3D grid covering a bounding box.

    Args:
        bounding_box: box to cover; its xmin/xmax, ymin/ymax and zmin/zmax
            attributes are read.
        n_long: number of grid intervals along the longest box edge; only used
            when ``spacing`` is None.
        spacing: grid spacing. If None it is ``lmax / n_long``, further capped at
            one tenth of the shortest edge of non-zero length so that short edges
            still receive several points.
        device: torch device of the returned tensors.
        against_edges: if True, 1D coordinates start exactly at the lower box
            edge (grid points lie on the edges) and may extend up to one spacing
            beyond the upper edge; otherwise they start half a spacing inside the
            lower edge (points are voxel centres) and may extend up to half a
            spacing beyond the upper edge.

    Returns:
        x1d, y1d, z1d: 1D coordinate tensors (dtype PRECISION).
        x, y, z: 3D coordinate tensors from ``torch.meshgrid(..., indexing='ij')``.

    Raises:
        ValueError: if ``spacing`` is None and the box has zero extent on every
            axis (no spacing can be derived from it).
    """
    bbox = bounding_box
    lengths = (bbox.xmax - bbox.xmin, bbox.ymax - bbox.ymin, bbox.zmax - bbox.zmin)

    if spacing is None:
        # Zero-length edges (planar model, coplanar atoms, ...) must not drive the
        # spacing to 0, otherwise torch.arange fails with a zero step.
        nonzero_lengths = [length for length in lengths if length > 0]
        if not nonzero_lengths:
            raise ValueError(
                'bounding box has zero extent on every axis; pass an explicit spacing'
            )
        spacing = max(lengths) / n_long
        spacing = min(spacing, min(nonzero_lengths) / 10)  # in case of too few points in short edge

    # small tolerance so that floating-point rounding does not drop the last point
    eps = spacing * 1e-3

    def _axis(lo, hi):
        # 1D coordinates along one axis, following the against_edges convention
        if against_edges:
            return torch.arange(lo, hi+spacing+eps, spacing, dtype=PRECISION, device=device)
        return torch.arange(lo+spacing/2, hi+spacing/2+eps, spacing, dtype=PRECISION, device=device)

    x1d = _axis(bbox.xmin, bbox.xmax)
    y1d = _axis(bbox.ymin, bbox.ymax)
    z1d = _axis(bbox.zmin, bbox.zmax)
    x, y, z = torch.meshgrid(x1d, y1d, z1d, indexing='ij')
    return x1d, y1d, z1d, x, y, z
class AbstractMathModel(ABC):
    """Base class for user-defined analytical ("math") models.

    Model parameters are stored as plain instance attributes, so their names
    must be valid Python identifiers (``update_params`` sets them by name).

    Required attribute:
        coord: Literal['car', 'sph', 'cyl'] -- coordinate system in which
            ``sld`` receives its three coordinate arguments.

    Reserved names (do not use them as parameter names; presumably they clash
    with reader keyword arguments such as read_math's n_long/spacing/device):
        filename, filename_or_class, sld_value,
        centering, n_long, spacing, device
    """

    @abstractmethod
    def __init__(self) -> None:
        """Set ``self.coord`` and the model's default parameters."""
        ...

    def update_params(self, **kwargs) -> None:
        """Write each keyword argument onto the instance as an attribute (no validation)."""
        self.__dict__.update(kwargs)

    @abstractmethod
    def bounding_box(self) -> tuple[float, float, float, float, float, float]:
        """Return the model's bounding box in Cartesian coordinates.

        Implementations should recompute it on every call, because parameters
        may have been changed (e.g. via ``update_params``) since the last call.

        Returns:
            tuple[float, float, float, float, float, float]: xmin, ymin, zmin, xmax, ymax, zmax
        """
        ...

    @abstractmethod
    def sld(self, u: Tensor, v: Tensor, w: Tensor) -> Tensor:
        """Evaluate the SLD at the given coordinates.

        The meaning of (u, v, w) depends on ``self.coord``:
            'car': x, y, z
            'sph': r, theta, phi
            'cyl': rho, theta, z

        Args:
            u (Tensor): first coordinate
            v (Tensor): second coordinate
            w (Tensor): third coordinate

        Returns:
            Tensor: SLD value at each coordinate
        """
        ...
88def import_mathmodel_class(filename: str | Path):
89 filename = Path(filename)
90 sys.path.append(str(filename.absolute().parent))
91 module = __import__(filename.stem)
92 mathmodel_class = module.MathModel
93 return mathmodel_class
@log
def read_math(
    filename_or_mathmodel: str|Path|AbstractMathModel,
    n_long: int = 50,
    spacing: Optional[float] = None,
    device: Optional[str|torch.device] = None,
    **mathmodel_params,
    ) -> GridModel:
    """Sample a math model's SLD on a regular grid.

    Args:
        filename_or_mathmodel: path (str or pathlib.Path) to a Python file that
            defines a ``MathModel`` class (instantiated with no arguments), or an
            already constructed math model instance.
        n_long: forwarded to ``meshgrid``.
        spacing: forwarded to ``meshgrid``.
        device: forwarded to ``meshgrid`` and ``GridModel``.
        **mathmodel_params: applied with ``update_params`` before sampling. When
            an instance is passed, it is modified in place.

    Returns:
        GridModel holding the SLD sampled at the grid points, which are voxel
        centres (``meshgrid`` is called with its default against_edges=False).
    """
    # Path is accepted as well as str; previously a Path was mistaken for a model instance.
    if isinstance(filename_or_mathmodel, (str, Path)):
        MathModel = import_mathmodel_class(filename_or_mathmodel)
        mathmodel = MathModel()
    else:
        mathmodel = filename_or_mathmodel

    mathmodel.update_params(**mathmodel_params)

    bounding_box = BoundingBox(*mathmodel.bounding_box())
    x1d, y1d, z1d, x, y, z = meshgrid(bounding_box, n_long=n_long, spacing=spacing, device=device)

    # the model's sld() expects coordinates in its own system (mathmodel.coord)
    u, v, w = calcfunc.convert_coord(x, y, z, 'car', mathmodel.coord)
    sld = mathmodel.sld(u, v, w)

    return GridModel(x1d, y1d, z1d, sld, device=device)
@log
def read_stl(
    filename: str,
    sld_value: int = 1,
    centering: bool = False,
    n_long: int = 50,
    spacing: Optional[float] = None,
    device: Optional[str|torch.device] = None
    ) -> GridModel:
    """Voxelize an STL surface into a GridModel of uniform SLD.

    Args:
        filename: STL file path.
        sld_value: SLD assigned to grid points found inside the surface
            (points outside get 0).
        centering: if True, translate the mesh so that the centre returned by
            ``get_mass_properties()[1]`` sits at the origin before gridding.
        n_long: forwarded to ``meshgrid``.
        spacing: forwarded to ``meshgrid``.
        device: forwarded to ``meshgrid`` and ``GridModel``.

    Returns:
        GridModel over the mesh's axis-aligned bounding box.
    """
    stl_mesh = mesh.Mesh.from_file(filename)

    if centering:
        stl_mesh.translate(-stl_mesh.get_mass_properties()[1])

    # flatten all triangle vertices into rows of coordinates to get the bounding box
    vertices = stl_mesh.vectors.reshape(-1, stl_mesh.vectors.shape[2])
    bounding_box = BoundingBox(*vertices.min(axis=0), *vertices.max(axis=0))

    x1d, y1d, z1d, x, y, z = meshgrid(bounding_box, n_long=n_long, spacing=spacing, device=device)

    # Inside test by ray casting: count how often a ray from each grid point
    # (random direction) crosses the triangles; an odd count means inside.
    grid_points = torch.stack((x, y, z), dim=-1)
    ray_direction = torch.rand(3, dtype=PRECISION, device=device) - 0.5
    triangles = torch.from_numpy(stl_mesh.vectors.copy()).to(PRECISION).to(device)
    crossings = calcfunc.moller_trumbore_intersect_count(grid_points, ray_direction, triangles)

    is_inside = crossings % 2  # 1 inside, 0 outside
    sld = (sld_value * is_inside).reshape(x.shape)

    return GridModel(x1d, y1d, z1d, sld, device=device)
@log
def read_pdb(
    filename: str,
    probe: Literal['xray', 'neutron'] = 'xray',
    wavelength: float = 1.54,
    n_long: int = 50,
    spacing: Optional[float] = None,
    device: Optional[str|torch.device] = None
    ) -> GridModel:
    """Build a GridModel from the atoms of a PDB file.

    Unlike the stl and math readers, a grid point here stands for the atoms
    located there, so the grid is generated against the bounding-box edges
    (``against_edges=True``). Each atom's scattering value is added to its
    nearest grid point.

    Args:
        filename: PDB file path.
        probe: 'neutron' uses periodictable's coherent scattering length
            ``neutron.b_c``; any other value uses the first element returned by
            periodictable's ``xray.scattering_factors(wavelength=...)``.
        wavelength: x-ray wavelength passed to periodictable (unused for
            neutrons); units follow periodictable (presumably Angstrom — confirm).
        n_long: forwarded to ``meshgrid``.
        spacing: forwarded to ``meshgrid``.
        device: forwarded to ``meshgrid`` and ``GridModel``.

    Returns:
        GridModel whose value at each grid point is the sum over the atoms
        snapped to that point.

    Raises:
        ValueError: if the file contains no atoms, or an element symbol is
            unknown to periodictable.
    """
    # slowest part
    pdbparser = Bio.PDB.PDBParser(QUIET=True)  # suppress PDBConstructionWarning
    pdb_structure = pdbparser.get_structure(Path(filename).stem, filename)
    # walk the structure once and reuse the atoms for both factors and coordinates
    atoms = list(pdb_structure.get_atoms())
    if not atoms:
        raise ValueError(f'no atoms found in PDB file {filename}')

    if probe == 'neutron':
        atom_f_func = lambda pt_element: pt_element.neutron.b_c
    else:
        atom_f_func = lambda pt_element: pt_element.xray.scattering_factors(wavelength=wavelength)[0]

    atom_f = torch.tensor(
        [atom_f_func(pt.elements.symbol(atom.element.capitalize())) for atom in atoms],
        dtype=PRECISION, device=device
    )
    atom_coord = torch.from_numpy(
        np.stack([atom.coord for atom in atoms], axis=0)
    ).to(PRECISION).to(device)

    bounding_box = BoundingBox(
        *(atom_coord.min(dim=0).values).tolist(),
        *(atom_coord.max(dim=0).values).tolist(),
    )
    x1d, y1d, z1d, x, y, z = meshgrid(bounding_box, n_long=n_long, spacing=spacing, device=device, against_edges=True)

    # Snap every atom to its nearest grid point. bounding_box.lower is presumably
    # the (xmin, ymin, zmin) corner as a tensor — confirm against BoundingBox.
    actual_spacing = x1d[1] - x1d[0]
    index = (atom_coord - bounding_box.lower.to(device)) / actual_spacing
    index = index.round().to(torch.int64)
    sld = torch.zeros_like(x)
    # Several atoms can share a voxel. `sld[idx] += atom_f` keeps only one value
    # per repeated index, so sum them explicitly with accumulate=True.
    sld.index_put_((index[:, 0], index[:, 1], index[:, 2]), atom_f, accumulate=True)

    return GridModel(x1d, y1d, z1d, sld, device=device)