Coverage for src\model2sas\readfile.py: 100%

96 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:14 +0800

1from abc import ABC, abstractmethod 

2from typing import Optional, Literal 

3from pathlib import Path 

4import sys 

5 

6import torch 

7from torch import Tensor 

8from stl import mesh 

9import Bio.PDB 

10import periodictable as pt 

11import numpy as np 

12 

13from .global_vars import PRECISION 

14from .model import GridModel, BoundingBox 

15from . import calcfunc 

16from .utils import log 

17 

18 

@log
def meshgrid(
    bounding_box: BoundingBox,
    n_long: int = 50,
    spacing: Optional[float] = None,
    device: Optional[str|torch.device] = None,
    against_edges: bool = False,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Build a regular grid with equal spacing on all axes over ``bounding_box``.

    Args:
        bounding_box: Box to cover. Only its ``xmin``, ``ymin``, ``zmin``,
            ``xmax``, ``ymax`` and ``zmax`` attributes are read.
        n_long: Number of intervals along the longest edge. Used only when
            ``spacing`` is None.
        spacing: Grid step. When None it is ``longest_edge / n_long``, capped
            at one tenth of the shortest non-zero edge.
        device: Forwarded to ``torch.arange``.
        against_edges: If True, each axis starts exactly on the lower face of
            the box. Otherwise each axis starts half a step inside it.

    Returns:
        ``(x1d, y1d, z1d, x, y, z)``: the three 1D axes followed by the three
        3D coordinate tensors from ``torch.meshgrid(..., indexing='ij')``.
    """
    bbox = bounding_box
    extents = (bbox.xmax - bbox.xmin, bbox.ymax - bbox.ymin, bbox.zmax - bbox.zmin)

    if spacing is None:
        spacing = max(extents) / n_long
        # In case of too few points in the short edge, cap the step at a tenth
        # of it. Zero-length edges (flat box, e.g. a planar molecule) are
        # ignored here: capping against them would force spacing to 0 and
        # torch.arange rejects a zero step.
        # NOTE(review): the cap is applied only to an automatically derived
        # spacing; confirm an explicit `spacing` was never meant to be capped.
        positive_extents = [length for length in extents if length > 0]
        if positive_extents:
            spacing = min(spacing, min(positive_extents) / 10)

    # torch.arange excludes its end value; the small eps keeps a point that
    # lands exactly on the nominal end from being dropped.
    eps = spacing * 1e-3

    def axis(start, stop) -> Tensor:
        return torch.arange(start, stop + eps, spacing, dtype=PRECISION, device=device)

    if against_edges:
        x1d = axis(bbox.xmin, bbox.xmax + spacing)
        y1d = axis(bbox.ymin, bbox.ymax + spacing)
        z1d = axis(bbox.zmin, bbox.zmax + spacing)
    else:
        half = spacing / 2
        x1d = axis(bbox.xmin + half, bbox.xmax + half)
        y1d = axis(bbox.ymin + half, bbox.ymax + half)
        z1d = axis(bbox.zmin + half, bbox.zmax + half)

    x, y, z = torch.meshgrid(x1d, y1d, z1d, indexing='ij')
    return x1d, y1d, z1d, x, y, z

40 

41 

42 

43 

class AbstractMathModel(ABC):
    """Abstract base for math models.

    Parameter names follow Python variable naming rules.

    Must have:
        coord: Literal['car', 'sph', 'cyl']

    Must avoid (as parameter names):
        filename, filename_or_class, sld_value,
        centering, n_long, spacing, device
    """

    @abstractmethod
    def __init__(self) -> None:
        """Define ``coord`` and the other model parameters here."""

    def update_params(self, **kwargs) -> None:
        """Set instance attributes from the given keyword arguments.

        Existing attributes with the same name are overwritten; new names
        are added.
        """
        vars(self).update(kwargs)

    @abstractmethod
    def bounding_box(self) -> tuple[float, float, float, float, float, float]:
        """Return the model boundary in Cartesian coordinates.

        The boundary is re-generated on every call, in case the parameters
        were altered in software.

        Returns:
            tuple[float, float, float, float, float, float]:
                xmin, ymin, zmin, xmax, ymax, zmax
        """

    @abstractmethod
    def sld(self, u: Tensor, v: Tensor, w: Tensor) -> Tensor:
        """Calculate the sld values at the given coordinates.

        The meaning of u, v, w depends on ``self.coord``:
            'car': x, y, z
            'sph': r, theta, phi
            'cyl': rho, theta, z

        Args:
            u (Tensor): 1st coordinate
            v (Tensor): 2nd coordinate
            w (Tensor): 3rd coordinate

        Returns:
            Tensor: sld value at each coordinate
        """

86 

87 

88def import_mathmodel_class(filename: str | Path): 

89 filename = Path(filename) 

90 sys.path.append(str(filename.absolute().parent)) 

91 module = __import__(filename.stem) 

92 mathmodel_class = module.MathModel 

93 return mathmodel_class 

94 

95 

@log
def read_math(
    filename_or_mathmodel: str|Path|AbstractMathModel,
    n_long: int = 50,
    spacing: Optional[float] = None,
    device: Optional[str|torch.device] = None,
    **mathmodel_params,
    ) -> GridModel:
    """Build a GridModel by sampling a math model on a regular grid.

    Args:
        filename_or_mathmodel: Either the path (``str`` or ``Path``) of a
            Python file defining a ``MathModel`` class, which is imported and
            instantiated with no arguments, or an already constructed math
            model instance.
        n_long: Forwarded to ``meshgrid``.
        spacing: Forwarded to ``meshgrid``.
        device: Forwarded to ``meshgrid`` and ``GridModel``.
        **mathmodel_params: Passed to ``mathmodel.update_params`` before the
            bounding box and sld are evaluated. When an instance was passed
            in, that instance is modified.

    Returns:
        GridModel holding the grid axes and the sld value at each grid point.
    """
    # Path is accepted alongside str, consistent with import_mathmodel_class.
    if isinstance(filename_or_mathmodel, (str, Path)):
        MathModel = import_mathmodel_class(filename_or_mathmodel)
        mathmodel = MathModel()
    else:
        mathmodel = filename_or_mathmodel

    mathmodel.update_params(**mathmodel_params)

    bounding_box = BoundingBox(*mathmodel.bounding_box())
    x1d, y1d, z1d, x, y, z = meshgrid(bounding_box, n_long=n_long, spacing=spacing, device=device)

    # The grid is Cartesian; convert it to the model's own coordinate system
    # before asking the model for sld values.
    u, v, w = calcfunc.convert_coord(x, y, z, 'car', mathmodel.coord)
    sld = mathmodel.sld(u, v, w)

    return GridModel(x1d, y1d, z1d, sld, device=device)

120 

121 

@log
def read_stl(
    filename: str,
    sld_value: int = 1,
    centering: bool = False,
    n_long: int = 50,
    spacing: Optional[float] = None,
    device: Optional[str|torch.device] = None
    ) -> GridModel:
    """Build a GridModel from an STL mesh, with ``sld_value`` inside the model.

    Args:
        filename: STL file to load.
        sld_value: Value assigned to grid points found inside the mesh;
            points outside get 0.
        centering: If True, translate the mesh so that the center reported by
            ``get_mass_properties()`` moves to (0, 0, 0).
        n_long: Forwarded to ``meshgrid``.
        spacing: Forwarded to ``meshgrid``.
        device: Forwarded to ``meshgrid`` and ``GridModel``; the ray and the
            triangles are also created on it.

    Returns:
        GridModel holding the grid axes and the sld value at each grid point.
    """
    stlmesh = mesh.Mesh.from_file(filename)

    if centering:
        # move model center to (0,0,0)
        stlmesh.translate(-stlmesh.get_mass_properties()[1])

    # Flatten the per-triangle vertex array into one list of vertices and
    # take its per-axis extremes as the bounding box.
    vertices = stlmesh.vectors
    shape = vertices.shape
    vertices = vertices.reshape((shape[0] * shape[1], shape[2]))
    bounding_box = BoundingBox(*vertices.min(axis=0), *vertices.max(axis=0))

    x1d, y1d, z1d, x, y, z = meshgrid(bounding_box, n_long=n_long, spacing=spacing, device=device)

    # Check which points are inside the model by ray intersection: cast the
    # same randomly oriented ray from every grid point and count the
    # triangles it hits.
    grid_points = torch.stack((x, y, z), dim=-1)
    direction = torch.rand(3, dtype=PRECISION, device=device) - 0.5
    facets = torch.from_numpy(stlmesh.vectors.copy()).to(PRECISION).to(device)
    hit_count = calcfunc.moller_trumbore_intersect_count(grid_points, direction, facets)

    inside = hit_count % 2  # 1 is in, 0 is out
    sld = (sld_value * inside).reshape(x.shape)

    return GridModel(x1d, y1d, z1d, sld, device=device)

156 

157 

@log
def read_pdb(
    filename: str,
    probe: Literal['xray', 'neutron'] = 'xray',
    wavelength: float = 1.54,
    n_long: int = 50,
    spacing: Optional[float] = None,
    device: Optional[str|torch.device] = None
    ) -> GridModel:
    """Build a GridModel from a PDB file by placing each atom on a grid point.

    Unlike the stl and math readers, a grid point in a pdb model stands for
    an atom, so the meshgrid is generated against the bounding-box edges.

    Args:
        filename: PDB file to parse.
        probe: ``'neutron'`` uses each element's ``neutron.b_c``; any other
            value uses the first item returned by
            ``xray.scattering_factors(wavelength=wavelength)``.
        wavelength: Passed to ``xray.scattering_factors``. Unused for neutrons.
        n_long: Forwarded to ``meshgrid``.
        spacing: Forwarded to ``meshgrid``.
        device: Device for all created tensors and for the ``GridModel``.

    Returns:
        GridModel whose sld at each voxel is the sum of the scattering values
        of all atoms rounded to that voxel.
    """
    # slowest part
    pdbparser = Bio.PDB.PDBParser(QUIET=True)  # suppress PDBConstructionWarning
    pdb_structure = pdbparser.get_structure(Path(filename).stem, filename)

    if probe == 'neutron':
        def atom_f_func(pt_element):
            return pt_element.neutron.b_c
    else:
        def atom_f_func(pt_element):
            return pt_element.xray.scattering_factors(wavelength=wavelength)[0]

    # Walk the structure once and reuse the atom list for both the
    # scattering values and the coordinates.
    atoms = list(pdb_structure.get_atoms())
    atom_f = torch.tensor(
        [atom_f_func(pt.elements.symbol(atom.element.capitalize())) for atom in atoms],
        dtype=PRECISION, device=device
    )
    atom_coord = torch.from_numpy(
        np.stack([atom.coord for atom in atoms], axis=0)
    ).to(PRECISION).to(device)

    bounding_box = BoundingBox(
        *(atom_coord.min(dim=0).values).tolist(),
        *(atom_coord.max(dim=0).values).tolist(),
    )
    x1d, y1d, z1d, x, y, z = meshgrid(bounding_box, n_long=n_long, spacing=spacing, device=device, against_edges=True)

    # Round every atom to its nearest grid point, measured from the lower
    # corner of the bounding box.
    actual_spacing = x1d[1] - x1d[0]
    index = (atom_coord - bounding_box.lower.to(device)) / actual_spacing
    index = index.round().to(torch.int64)

    sld = torch.zeros_like(x)
    # Several atoms may fall in one voxel. A plain `sld[i, j, k] += atom_f`
    # keeps only one contribution per repeated index, so accumulate explicitly.
    sld.index_put_((index[:, 0], index[:, 1], index[:, 2]), atom_f, accumulate=True)

    return GridModel(x1d, y1d, z1d, sld, device=device)