Coverage for src\model2sas\model.py: 96%

196 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-09 17:12 +0800

1from typing import Literal, Optional, Callable 

2from dataclasses import dataclass 

3from functools import partial 

4from abc import ABC, abstractmethod 

5 

6import torch 

7from torch import Tensor 

8 

9from .global_vars import PRECISION 

10from . import calcfunc 

11from .calcfunc import complex_increase_argument 

12from .utils import log 

13 

14 

class Model(ABC):
    """Common interface shared by every scattering model.

    A concrete model reports the largest usable |q| (``maxq``) and answers
    three reciprocal-space queries: the complex amplitude and the intensity
    at explicit (qx, qy, qz) points, and the orientation-averaged intensity
    for a 1D list of |q| values.
    """

    @property
    @abstractmethod
    def maxq(self) -> float:
        """Largest |q| for which this model can return values."""

    @abstractmethod
    def amplitude(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Complex scattering amplitude at the reciprocal points (qx, qy, qz).

        The returned tensor has the same shape as, and lives on the same
        device as, the inputs.
        """

    @abstractmethod
    def intensity(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Scattering intensity at the reciprocal points (qx, qy, qz).

        The returned tensor has the same shape as, and lives on the same
        device as, the inputs.
        """

    @abstractmethod
    def intensity_ave(self, q1d: Tensor, offset: int = 100) -> Tensor:
        """Orientation-averaged intensity for every |q| in ``q1d``.

        ``q1d`` is expected to be a 1D tensor; the result has its shape and
        device. The meaning of ``offset`` is implementation-defined.
        """

41 

42 

@dataclass
class BoundingBox:
    """Axis-aligned box described by its minimum and maximum corner."""
    xmin: float
    ymin: float
    zmin: float
    xmax: float
    ymax: float
    zmax: float

    def contain(self, x: Tensor, y: Tensor, z: Tensor) -> Tensor:
        """Boolean mask that is True where (x, y, z) lies inside the box.

        Both bounds are inclusive on every axis.
        """
        inside_x = (x >= self.xmin) & (x <= self.xmax)
        inside_y = (y >= self.ymin) & (y <= self.ymax)
        inside_z = (z >= self.zmin) & (z <= self.zmax)
        return inside_x & inside_y & inside_z

    @property
    def lower(self) -> Tensor:
        """Minimum corner (xmin, ymin, zmin) as a tensor of dtype PRECISION."""
        corner = (self.xmin, self.ymin, self.zmin)
        return torch.tensor(corner, dtype=PRECISION)

    @property
    def upper(self) -> Tensor:
        """Maximum corner (xmax, ymax, zmax) as a tensor of dtype PRECISION."""
        corner = (self.xmax, self.ymax, self.zmax)
        return torch.tensor(corner, dtype=PRECISION)

62 

63 

@dataclass
class Grid:
    """Equally spaced 3D grid: 1D axis coordinates plus values on the grid.

    After construction, ``spacing`` holds the absolute step of ``x1d`` and
    ``bounding_box`` encloses every voxel, i.e. the extent of the grid
    points widened by half a spacing on each side.
    """
    x1d: Tensor
    y1d: Tensor
    z1d: Tensor
    value3d: Tensor

    def __post_init__(self):
        # The step is read from the x axis only; y and z are assumed to share it.
        self.spacing = torch.abs(self.x1d[1] - self.x1d[0]).item()
        half = self.spacing / 2
        axes = (self.x1d, self.y1d, self.z1d)
        lows = [(axis.min() - half).item() for axis in axes]
        highs = [(axis.max() + half).item() for axis in axes]
        self.bounding_box = BoundingBox(*lows, *highs)

    @property
    def coord3d(self) -> tuple[Tensor, Tensor, Tensor]:
        """3D coordinate arrays of every grid point, using 'ij' indexing."""
        return tuple(torch.meshgrid(self.x1d, self.y1d, self.z1d, indexing='ij'))

    def interpolate(self, x: Tensor, y: Tensor, z: Tensor) -> Tensor:
        """Trilinearly interpolate ``value3d`` at the points (x, y, z)."""
        # Signed x step, passed to the helper as a tensor (not the abs `spacing`).
        step = self.x1d[1] - self.x1d[0]
        return calcfunc.trilinear_interp(
            x, y, z, self.x1d, self.y1d, self.z1d, self.value3d, step
        )

94 

class ReciprocalGrid(Grid):
    """Reciprocal-space grid that stores only the upper half (z >= 0).

    The stored values are the Fourier transform of a real density, so the
    lower half follows from Friedel's law, F(-q) = conj(F(q)): points with
    z < 0 are mirrored through the origin and the interpolated value is
    complex-conjugated. The grid is (approximately) centred at the origin in
    the xy plane.
    """
    def interpolate(self, x: Tensor, y: Tensor, z: Tensor) -> Tensor:
        """Interpolate at arbitrary (x, y, z), reconstructing the z < 0 half.

        For z < 0 the value is conj(F(-x, -y, -z)). Without the conjugation
        the phase would be wrong there. |F|^2 is unaffected, but amplitudes
        combined with translation phases or with other components are not.
        """
        lower_half = z < 0
        sign = torch.ones_like(z)
        sign[lower_half] = -1.
        F = super().interpolate(sign*x, sign*y, sign*z)
        return torch.where(lower_half, F.conj(), F)

    @property
    def max_radius(self):
        """Largest radius of an origin-centred sphere that fits inside the grid.

        Only the upper z end is considered, because the lower half is
        obtained by symmetry. Returned as a 0-d tensor on the grid's device.
        """
        # torch.stack keeps device and dtype. torch.tensor over 0-d tensors
        # would copy to CPU with an inferred dtype.
        ends = torch.stack((self.x1d[0], self.x1d[-1], self.y1d[0], self.y1d[-1], self.z1d[-1]))
        return ends.abs().min()

109 

110 

class GeoTransforms:
    """Ordered record of geometric transforms (translations and rotations).

    Each transform is stored together with a real-space function (currently
    an unimplemented stub) and a reciprocal-space function that maps
    ``(qx, qy, qz, phase)`` to a new ``(qx, qy, qz, phase)``.
    """

    @dataclass
    class Record:
        # which transform this record describes
        type: Literal['translate', 'rotate']
        # the raw arguments the transform was registered with
        args: tuple
        # real-space implementation with args already bound (stub for now)
        func_real: Callable
        # reciprocal-space implementation with args already bound:
        # (qx, qy, qz, phase) -> (qx, qy, qz, phase)
        func_reciprocal: Callable[[Tensor, Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor, Tensor]]

    def __init__(self) -> None:
        self.records: list[GeoTransforms.Record] = []

    def _register(self, kind, args, real_impl, reciprocal_impl) -> None:
        """Append one record, binding ``args`` as leading arguments of both implementations."""
        self.records.append(self.Record(
            type=kind,
            args=args,
            func_real=partial(real_impl, *args),
            func_reciprocal=partial(reciprocal_impl, *args),
        ))

    def add_translate(self, vx: float, vy: float, vz: float):
        """Append a translation by the vector (vx, vy, vz)."""
        self._register('translate', (vx, vy, vz), self._translate_real, self._translate_reciprocal)

    def add_rotate(self, v_axis: tuple[float, float, float], angle: float):
        """Append a rotation by ``angle`` about the axis ``v_axis``."""
        self._register('rotate', (v_axis, angle), self._rotate_real, self._rotate_reciprocal)

    @staticmethod
    def _translate_real(vx, vy, vz, x, y, z, value):
        """Real-space translation: not implemented yet (returns None)."""

    @staticmethod
    def _translate_reciprocal(vx: float, vy: float, vz: float, qx: Tensor, qy: Tensor, qz: Tensor, complex_argument_addend: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Translation in reciprocal space.

        q is left unchanged; -(q . v) is added to the accumulated phase argument.
        """
        phase_shift = -(qx*vx + qy*vy + qz*vz)
        return qx, qy, qz, complex_argument_addend + phase_shift

    @staticmethod
    def _rotate_real(v_axis, angle, x, y, z, value):
        """Real-space rotation: not implemented yet (returns None)."""

    @staticmethod
    def _rotate_reciprocal(v_axis: tuple[float, float, float], angle: float, qx: Tensor, qy: Tensor, qz: Tensor, complex_argument_addend: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Rotation in reciprocal space.

        q is rotated about ``v_axis`` by ``-angle`` (the inverse rotation);
        the phase argument is passed through unchanged.
        """
        qx_r, qy_r, qz_r = calcfunc.euler_rodrigues_rotate(qx, qy, qz, v_axis, -angle)
        return qx_r, qy_r, qz_r, complex_argument_addend

    def apply_real(self):
        """Apply the transforms in real space: not implemented yet (returns None)."""

    def apply_reciprocal(self, qx: Tensor, qy: Tensor, qz: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Map (qx, qy, qz) back through every recorded transform.

        Records are applied newest-first, i.e. the inverse order of
        registration. Returns the mapped q components plus the total phase
        argument accumulated from translations, which starts at zero.
        """
        phase = torch.zeros_like(qx)
        for record in self.records[::-1]:
            qx, qy, qz, phase = record.func_reciprocal(qx, qy, qz, phase)
        return qx, qy, qz, phase

165 

166 

167 

class GridModel(Model):
    """Scattering model of a scattering-length-density (SLD) grid.

    Typical use: construct from grid coordinates and SLD, optionally add
    geometric transforms with :meth:`translate` / :meth:`rotate`, call
    :meth:`scatter` once to build the reciprocal grid, then query
    :meth:`amplitude`, :meth:`intensity` or :meth:`intensity_ave`.
    ``maxq`` and the query methods need :meth:`scatter` to have run first.
    """

    def __init__(self, x1d: Tensor, y1d: Tensor, z1d: Tensor, sld: Tensor, device: Optional[str|torch.device] = None) -> None:
        """
        Args:
            x1d, y1d, z1d: 1D coordinates of the equally spaced real-space grid.
            sld: SLD values on the grid (stored as the grid's ``value3d``).
            device: device used for all computation; defaults to ``sld.device``.
        """
        self.device = sld.device if device is None else torch.device(device)
        # Move every input to the model device. `.to(device)` would be a no-op
        # when device is None, leaving the coordinates on a different device.
        x1d, y1d, z1d, sld = (t.to(self.device) for t in (x1d, y1d, z1d, sld))
        self.real_grid = Grid(x1d, y1d, z1d, sld)

        self.transforms = GeoTransforms()
        self.clear_transforms()  # adds the base translation record for grid placement

    def clear_transforms(self):
        """Remove all user transforms, keeping only the base grid translation."""
        self.transforms.records.clear()
        # The FFT puts the lowest-corner grid point at the origin (0, 0, 0), so a
        # base translation to the true corner position is needed to obtain the
        # model's actual scattering amplitude.
        half = self.real_grid.spacing / 2
        bbox = self.real_grid.bounding_box
        self.translate(bbox.xmin + half, bbox.ymin + half, bbox.zmin + half)

    @property
    def sld(self) -> Tensor:
        """SLD values of the real-space grid."""
        return self.real_grid.value3d

    @property
    def maxq(self) -> float:
        """Largest |q| fully covered by the reciprocal grid (requires :meth:`scatter`)."""
        return self.reciprocal_grid.max_radius.item()

    @log
    def scatter(self, nq: Optional[int] = None, form_factor: bool = True):
        """Compute the reciprocal-space grid of the SLD by FFT.

        Args:
            nq: number of reciprocal grid points per axis when ``form_factor``
                is True. Defaults to ``min(600, 10 * n_real)``, where ``n_real``
                is the longest real-space axis. It is never smaller than
                ``n_real``, so the SLD is zero-padded, never cropped.
            form_factor: if False, exactly ``n_real`` points are used and
                ``nq`` is ignored.

        The result is stored in ``self.reciprocal_grid``.
        """
        # determine the 1d grid number in reciprocal space
        real_size_max = max(*self.real_grid.x1d.shape, *self.real_grid.y1d.shape, *self.real_grid.z1d.shape)
        if not form_factor:
            nq = real_size_max
        elif nq is None:
            # Oversample 10x for smooth interpolation, capped at 600 per axis to
            # bound resource use. Never go below the real grid size: with a
            # smaller `s`, rfftn would silently crop the SLD.
            nq = max(real_size_max, min(600, 10*real_size_max))
        else:
            nq = max(nq, real_size_max)

        d = self.real_grid.spacing
        # x and y are full FFT axes, shifted so q = 0 sits in the middle;
        # z keeps only the non-negative half produced by rfftn.
        s1d = torch.fft.fftshift(torch.fft.fftfreq(nq, d=d, device=self.device))
        s1dz = torch.fft.rfftfreq(nq, d=d, device=self.device)
        q1d, q1dz = 2*torch.pi*s1d, 2*torch.pi*s1dz

        F_half = torch.fft.rfftn(self.real_grid.value3d, s=(nq, nq, nq))
        F_half = torch.fft.fftshift(F_half, dim=(0, 1))

        ##### Continuous-density correction #####
        # Correct the discrete density to a continuous one by multiplying by
        # the scattering function of a single box-shaped voxel. This also
        # removes the intensity difference caused by different real-space
        # spacings.
        # ATTENTION!
        # Results for a sphere show that applying the continuous-density
        # correction gives a worse result than before: a larger deviation
        # from the -4 slope line.
        # sin(q*d/2) / (q*d/2) == torch.sinc(s*d) because q = 2*pi*s;
        # torch.sinc is exactly 1 at 0, so no NaN patching is needed.
        sinc1d, sinc1dz = torch.sinc(s1d*d), torch.sinc(s1dz*d)
        box_scatt = torch.einsum('i,j,k->ijk', sinc1d, sinc1d, sinc1dz)
        F_half = F_half * d**3 * box_scatt

        self.reciprocal_grid = ReciprocalGrid(q1d, q1d, q1dz, F_half)

    def translate(self, vx: float, vy: float, vz: float):
        """Append a translation of the model by (vx, vy, vz)."""
        self.transforms.add_translate(vx, vy, vz)

    def rotate(self, v_axis: tuple[float, float, float], angle: float):
        """Append a rotation of the model by ``angle`` about ``v_axis``.

        NOTE(review): angle units follow calcfunc.euler_rodrigues_rotate
        (presumably radians), to be confirmed.
        """
        self.transforms.add_rotate(v_axis, angle)

    @log
    def amplitude(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Complex amplitude at (qx, qy, qz), including all recorded transforms.

        Computation runs on the model device; the result is returned on the
        device of ``qx``.
        """
        input_device = qx.device
        qx, qy, qz = qx.to(self.device), qy.to(self.device), qz.to(self.device)
        # Map q back through the transforms (newest first) and collect the
        # phase argument contributed by translations.
        qx, qy, qz, complex_argument_addend = self.transforms.apply_reciprocal(qx, qy, qz)
        F = self.reciprocal_grid.interpolate(qx, qy, qz)
        F = complex_increase_argument(F, complex_argument_addend)
        return F.to(input_device)

    @log
    def intensity(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Intensity |F|^2 at (qx, qy, qz)."""
        F = self.amplitude(qx, qy, qz)
        return F.real**2 + F.imag**2

    @log
    def intensity_ave(self, q1d: Tensor, offset: int = 100) -> Tensor:
        """Orientation-averaged intensity for every |q| in the 1D tensor ``q1d``.

        For each valid q, ``round(q / q_ref) + offset`` points are sampled on
        the sphere of radius q and their intensities are averaged. ``q_ref``
        is the smallest positive valid q. Entries with q > maxq, or NaN q,
        are returned as NaN.
        """
        maxq = self.maxq  # hoisted: each access does a .item() device sync
        valid = q1d <= maxq
        I1d = torch.full_like(q1d, torch.nan)
        q1d_effective = q1d[valid]
        if q1d_effective.numel() == 0:
            return I1d  # nothing within range; indexing [0] below would fail

        # Use the smallest positive q as reference: dividing by a q of 0
        # would make N inf/NaN.
        positive_q = q1d_effective[q1d_effective > 0]
        q_ref = positive_q.min() if positive_q.numel() > 0 else q1d_effective.new_tensor(1.)
        N = torch.round(q1d_effective/q_ref) + offset

        qx, qy, qz = calcfunc.multiple_spherical_sampling(q1d_effective, N)

        Iall = self.intensity(qx, qy, qz)
        I = torch.zeros_like(q1d_effective)
        begin = 0
        for i, n in enumerate(N.int().tolist()):
            I[i] = Iall[begin:begin+n].mean()
            begin += n

        I1d[valid] = I
        return I1d

277 

278 

class AssemblyModel(Model):
    """Coherent combination of several models.

    Component amplitudes are summed before squaring, so the intensity
    includes interference between the components.
    """

    def __init__(self, *grid_models: 'GridModel|AssemblyModel') -> None:
        """Store the component models (GridModel or nested AssemblyModel)."""
        self.components = grid_models

    @property
    def maxq(self) -> float:
        """Smallest ``maxq`` among the components.

        Raises:
            ValueError: if the assembly has no components.
        """
        # A generator works for any number of components. The previous
        # `min(*[...])` became `min(float)` for one component, a TypeError.
        return min(model.maxq for model in self.components)

    @log
    def amplitude(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Sum of the component amplitudes at (qx, qy, qz)."""
        F = torch.complex(torch.zeros_like(qx), torch.zeros_like(qx))
        for model in self.components:
            F += model.amplitude(qx, qy, qz)
        return F

    @log
    def intensity(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Intensity |F|^2 of the summed amplitude at (qx, qy, qz)."""
        F = self.amplitude(qx, qy, qz)
        return F.real**2 + F.imag**2

    @log
    def intensity_ave(self, q1d: Tensor, offset: int = 100) -> Tensor:
        """Orientation-averaged intensity for every |q| in the 1D tensor ``q1d``.

        For each valid q, ``round(q / q_ref) + offset`` points are sampled on
        the sphere of radius q and their intensities are averaged. ``q_ref``
        is the smallest positive valid q. Entries with q > maxq, or NaN q,
        are returned as NaN.
        """
        maxq = self.maxq  # hoisted: evaluated once instead of three times
        valid = q1d <= maxq
        I1d = torch.full_like(q1d, torch.nan)
        q1d_effective = q1d[valid]
        if q1d_effective.numel() == 0:
            return I1d  # nothing within range; indexing [0] below would fail

        # Use the smallest positive q as reference: dividing by a q of 0
        # would make N inf/NaN.
        positive_q = q1d_effective[q1d_effective > 0]
        q_ref = positive_q.min() if positive_q.numel() > 0 else q1d_effective.new_tensor(1.)
        N = torch.round(q1d_effective/q_ref) + offset

        qx, qy, qz = calcfunc.multiple_spherical_sampling(q1d_effective, N)

        Iall = self.intensity(qx, qy, qz)
        I = torch.zeros_like(q1d_effective)
        begin = 0
        for i, n in enumerate(N.int().tolist()):
            I[i] = Iall[begin:begin+n].mean()
            begin += n

        I1d[valid] = I
        return I1d