Coverage for src\model2sas\model.py: 96%
196 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-09 17:12 +0800
1from typing import Literal, Optional, Callable
2from dataclasses import dataclass
3from functools import partial
4from abc import ABC, abstractmethod
6import torch
7from torch import Tensor
9from .global_vars import PRECISION
10from . import calcfunc
11from .calcfunc import complex_increase_argument
12from .utils import log
class Model(ABC):
    """Abstract interface shared by every scattering model in this module.

    A concrete model reports the largest usable |q| (``maxq``) and evaluates
    the complex amplitude, the intensity and the orientation-averaged
    intensity in reciprocal space.
    """

    @property
    @abstractmethod
    def maxq(self) -> float:
        """Largest |q| value at which the model can be evaluated."""

    @abstractmethod
    def amplitude(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Complex scattering amplitude at the reciprocal points (qx, qy, qz).

        The returned tensor has the same shape and lives on the same device
        as the input qx/qy/qz.
        """

    @abstractmethod
    def intensity(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Scattering intensity at the reciprocal points (qx, qy, qz).

        The returned tensor has the same shape and lives on the same device
        as the input qx/qy/qz.
        """

    @abstractmethod
    def intensity_ave(self, q1d: Tensor, offset: int = 100) -> Tensor:
        """Orientation-averaged intensity for each |q| value in ``q1d``.

        ``q1d`` should be a 1D tensor; the result matches its shape and
        device. ``offset`` is forwarded to implementations (the concrete
        models in this module add it to the per-shell sample count).
        """
@dataclass
class BoundingBox:
    """Axis-aligned box described by its minimum and maximum corners."""
    xmin: float
    ymin: float
    zmin: float
    xmax: float
    ymax: float
    zmax: float

    def contain(self, x: Tensor, y: Tensor, z: Tensor) -> Tensor:
        """Boolean mask that is True where (x, y, z) lies inside the box.

        Both bounds are inclusive on every axis.
        """
        inside_x = (x >= self.xmin) & (x <= self.xmax)
        inside_y = (y >= self.ymin) & (y <= self.ymax)
        inside_z = (z >= self.zmin) & (z <= self.zmax)
        return inside_x & inside_y & inside_z

    @property
    def lower(self) -> Tensor:
        """Minimum corner (xmin, ymin, zmin) as a 1D tensor of dtype PRECISION."""
        corner = (self.xmin, self.ymin, self.zmin)
        return torch.tensor(corner, dtype=PRECISION)

    @property
    def upper(self) -> Tensor:
        """Maximum corner (xmax, ymax, zmax) as a 1D tensor of dtype PRECISION."""
        corner = (self.xmax, self.ymax, self.zmax)
        return torch.tensor(corner, dtype=PRECISION)
@dataclass
class Grid:
    """Equally spaced 3D grid holding one value per grid point.

    ``spacing`` is derived from the first two entries of ``x1d`` and the same
    value is used for y and z. ``bounding_box`` extends half a spacing beyond
    the outermost grid points on every side.
    """
    x1d: Tensor
    y1d: Tensor
    z1d: Tensor
    value3d: Tensor

    def __post_init__(self):
        step = torch.abs(self.x1d[1] - self.x1d[0]).item()
        self.spacing = step
        half = step / 2
        axes = (self.x1d, self.y1d, self.z1d)
        # Subtract/add in tensor arithmetic first, then convert to float.
        lows = [(axis.min() - half).item() for axis in axes]
        highs = [(axis.max() + half).item() for axis in axes]
        self.bounding_box = BoundingBox(*lows, *highs)

    @property
    def coord3d(self) -> tuple[Tensor, Tensor, Tensor]:
        """Full 3D coordinate arrays of the grid points ('ij' indexing)."""
        mesh = torch.meshgrid(self.x1d, self.y1d, self.z1d, indexing='ij')
        return mesh[0], mesh[1], mesh[2]

    def interpolate(self, x: Tensor, y: Tensor, z: Tensor) -> Tensor:
        """Interpolate ``value3d`` at arbitrary points with calcfunc.trilinear_interp."""
        # The helper receives the signed x step as a tensor, not ``self.spacing``.
        step = self.x1d[1] - self.x1d[0]
        return calcfunc.trilinear_interp(
            x, y, z, self.x1d, self.y1d, self.z1d, self.value3d, step
        )
class ReciprocalGrid(Grid):
    """Reciprocal-space grid of which only the half z >= 0 is stored.

    The stored values are assumed to be the Fourier transform of a
    real-valued function (in this module they come from ``torch.fft.rfftn``,
    which only accepts real input). Such a transform is Hermitian,
    F(-q) = conj(F(q)), so values in the half z < 0 are obtained by
    reflecting the point through the origin and conjugating.
    In x and y the grid is roughly centred on 0.
    """

    def interpolate(self, x: Tensor, y: Tensor, z: Tensor) -> Tensor:
        """Interpolate at (x, y, z), using Hermitian symmetry for z < 0."""
        mirrored = z < 0
        sign = torch.ones_like(z)
        sign[mirrored] = -1.
        values = super().interpolate(sign*x, sign*y, sign*z)
        # F(q) = conj(F(-q)) for the transform of real data. Without the
        # conjugation the phase is wrong for z < 0, which corrupts
        # interference between summed amplitudes. Trilinear interpolation
        # is linear with real weights, so conj commutes with it.
        # conj_physical is a no-op for non-complex values.
        return torch.where(mirrored, torch.conj_physical(values), values)

    @property
    def max_radius(self) -> Tensor:
        """Radius of the largest origin-centred sphere inside the grid, as a 0-d tensor.

        This is the smallest absolute extent among both x ends, both y ends
        and the upper z end. The lower z end is skipped because only z >= 0
        is stored. The result stays on the grid's device and dtype.
        """
        extents = torch.stack((self.x1d[0], self.x1d[-1], self.y1d[0], self.y1d[-1], self.z1d[-1]))
        return extents.abs().min()
class GeoTransforms:
    """Ordered list of geometric transforms (translations and rotations).

    Each record holds a real-space function and a reciprocal-space function.
    Only the reciprocal-space path does work. The real-space functions and
    :meth:`apply_real` are placeholders that do nothing and return None.
    """

    @dataclass
    class Record:
        type: Literal['translate', 'rotate']
        args: tuple
        func_real: Callable
        func_reciprocal: Callable[[Tensor, Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor, Tensor]]

    def __init__(self) -> None:
        self.records: list[GeoTransforms.Record] = []

    def add_translate(self, vx: float, vy: float, vz: float):
        """Append a translation by the vector (vx, vy, vz)."""
        record = self.Record(
            type='translate',
            args=(vx, vy, vz),
            func_real=partial(self._translate_real, vx, vy, vz),
            func_reciprocal=partial(self._translate_reciprocal, vx, vy, vz),
        )
        self.records.append(record)

    def add_rotate(self, v_axis: tuple[float, float, float], angle: float):
        """Append a rotation by ``angle`` about the axis ``v_axis``."""
        record = self.Record(
            type='rotate',
            args=(v_axis, angle),
            func_real=partial(self._rotate_real, v_axis, angle),
            func_reciprocal=partial(self._rotate_reciprocal, v_axis, angle),
        )
        self.records.append(record)

    @staticmethod
    def _translate_real(vx, vy, vz, x, y, z, value):
        pass  # placeholder: real-space translation is not implemented

    @staticmethod
    def _translate_reciprocal(vx: float, vy: float, vz: float, qx: Tensor, qy: Tensor, qz: Tensor, complex_argument_addend: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        # A translation leaves q unchanged and adds -(q . v) to the
        # accumulated complex argument.
        q_dot_v = qx*vx + qy*vy + qz*vz
        return qx, qy, qz, complex_argument_addend + (-q_dot_v)

    @staticmethod
    def _rotate_real(v_axis, angle, x, y, z, value):
        pass  # placeholder: real-space rotation is not implemented

    @staticmethod
    def _rotate_reciprocal(v_axis: tuple[float, float, float], angle: float, qx: Tensor, qy: Tensor, qz: Tensor, complex_argument_addend: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        # Look up the rotated model by rotating q the opposite way (-angle)
        # about the same axis; the accumulated argument is unchanged.
        rotated_qx, rotated_qy, rotated_qz = calcfunc.euler_rodrigues_rotate(qx, qy, qz, v_axis, -angle)
        return rotated_qx, rotated_qy, rotated_qz, complex_argument_addend

    def apply_real(self):
        pass  # placeholder: real-space application is not implemented

    def apply_reciprocal(self, qx: Tensor, qy: Tensor, qz: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Map (qx, qy, qz) through every record's reciprocal-space function.

        Records are processed newest first, which undoes the transforms in
        reverse order of application. Returns the mapped q components and
        the accumulated complex argument, which starts from zeros.
        """
        phase = torch.zeros_like(qx)
        for record in self.records[::-1]:
            qx, qy, qz, phase = record.func_reciprocal(qx, qy, qz, phase)
        return qx, qy, qz, phase
class GridModel(Model):
    """Scattering model built from values sampled on an equally spaced real-space grid.

    :meth:`scatter` must be called first to compute the reciprocal-space
    amplitude grid via FFT. Only then can :meth:`amplitude`,
    :meth:`intensity`, :meth:`intensity_ave` or :attr:`maxq` be used.
    Transforms added with :meth:`translate` / :meth:`rotate` are applied on
    the fly in reciprocal space.
    """

    def __init__(self, x1d: Tensor, y1d: Tensor, z1d: Tensor, sld: Tensor, device: Optional[str|torch.device] = None) -> None:
        """
        Args:
            x1d, y1d, z1d: 1D grid coordinates along each axis.
            sld: values on the grid (3D tensor).
            device: device to store the grid on; defaults to ``sld.device``.
        """
        if device is None:
            self.device = sld.device
        else:
            self.device = torch.device(device)
        # Move everything to the resolved device. With device=None the
        # coordinates used to be left on their original device, possibly
        # different from sld's.
        x1d, y1d, z1d, sld = (t.to(self.device) for t in (x1d, y1d, z1d, sld))
        self.real_grid = Grid(x1d, y1d, z1d, sld)

        self.transforms = GeoTransforms()
        self.clear_transforms()  # add basic translate record for grid centering

    def clear_transforms(self):
        """Drop all user transforms, keeping only the base grid translation."""
        self.transforms.records.clear()
        # The FFT places the first (lowest-corner) grid point at (0, 0, 0),
        # so a base translation to that point's real coordinates is needed
        # to obtain the true scattering amplitude of the model.
        bbox = self.real_grid.bounding_box
        half = self.real_grid.spacing/2
        self.translate(bbox.xmin + half, bbox.ymin + half, bbox.zmin + half)

    @property
    def sld(self) -> Tensor:
        """The real-space grid values."""
        return self.real_grid.value3d

    @property
    def maxq(self) -> float:
        """Largest |q| covered by the reciprocal grid (requires scatter())."""
        return self.reciprocal_grid.max_radius.item()

    @log
    def scatter(self, nq: Optional[int] = None, form_factor: bool = True):
        """Compute the reciprocal-space amplitude grid by FFT.

        Stores the result in ``self.reciprocal_grid``.

        Args:
            nq: number of FFT points per axis. Used only when ``form_factor``
                is True. Defaults to ``min(600, 10 * size)``. In both cases
                it is raised to at least ``size``, the largest real-grid
                dimension.
            form_factor: if False, ``nq`` is ignored and the FFT size equals
                the largest real-grid dimension (no zero padding).
        """
        # determine 1d grid number in reciprocal space
        real_size_max = max(*self.real_grid.x1d.shape, *self.real_grid.y1d.shape, *self.real_grid.z1d.shape)
        if form_factor:
            if nq is None:
                nq = min(600, 10*real_size_max)  # in case of using too much resource
            # Never use fewer points than the real grid: rfftn trims its
            # input when ``s`` is smaller, which would crop the model.
            nq = max(nq, real_size_max)
        else:
            nq = real_size_max

        spacing = self.real_grid.spacing
        s1d = torch.fft.fftshift(torch.fft.fftfreq(nq, d=spacing, device=self.device))
        s1dz = torch.fft.rfftfreq(nq, d=spacing, device=self.device)

        q1d, q1dz = 2*torch.pi*s1d, 2*torch.pi*s1dz

        # Half spectrum along z (real input); center x and y on q = 0.
        F_half = torch.fft.rfftn(self.real_grid.value3d, s=(nq, nq, nq))
        F_half = torch.fft.fftshift(F_half, dim=(0, 1))

        ##### Continuous-density correction #####
        # Correct the discrete density to a continuous density by
        # multiplying by the scattering function of a single box voxel.
        # This also eliminates the intensity difference caused by
        # different spacings in real space.
        # ATTENTION!
        # Results for a sphere show that applying the continuous-density
        # correction gives a worse result than without it: a larger
        # deviation from the -4 slope line.
        def sinc(t: Tensor) -> Tensor:
            # sin(t)/t with the 0/0 limit at t = 0 replaced by 1
            return torch.nan_to_num(torch.sin(t)/t, nan=1.)

        sinc1d, sinc1dz = sinc(q1d*spacing/2), sinc(q1dz*spacing/2)
        box_scatt = torch.einsum('i,j,k->ijk', sinc1d, sinc1d, sinc1dz)
        F_half = F_half * spacing**3 * box_scatt

        self.reciprocal_grid = ReciprocalGrid(q1d, q1d, q1dz, F_half)

    def translate(self, vx: float, vy: float, vz: float):
        """Translate the model by (vx, vy, vz)."""
        self.transforms.add_translate(vx, vy, vz)

    def rotate(self, v_axis: tuple[float, float, float], angle: float):
        """Rotate the model by ``angle`` about ``v_axis``.

        The angle is passed to calcfunc.euler_rodrigues_rotate; check that
        helper for its units.
        """
        self.transforms.add_rotate(v_axis, angle)

    @log
    def amplitude(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Complex amplitude at (qx, qy, qz), returned on the inputs' device."""
        input_device = qx.device
        qx, qy, qz = qx.to(self.device), qy.to(self.device), qz.to(self.device)
        # Map q through the inverse transforms and collect the phase from translations.
        qx, qy, qz, complex_argument_addend = self.transforms.apply_reciprocal(qx, qy, qz)
        F = self.reciprocal_grid.interpolate(qx, qy, qz)
        F = complex_increase_argument(F, complex_argument_addend)
        return F.to(input_device)

    @log
    def intensity(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Intensity |F|^2 at (qx, qy, qz)."""
        F = self.amplitude(qx, qy, qz)
        return F.real**2 + F.imag**2

    @log
    def intensity_ave(self, q1d: Tensor, offset: int = 100) -> Tensor:
        """Orientation-averaged intensity for each |q| in the 1D tensor ``q1d``.

        Each in-range shell is sampled by calcfunc.multiple_spherical_sampling
        with ``round(q / q_ref) + offset`` points, where ``q_ref`` is the first
        positive in-range q. The sampled intensities are then averaged per
        shell. Entries with q > maxq are NaN.
        """
        maxq = self.maxq
        in_range = q1d <= maxq
        q1d_effective = q1d[in_range]

        I1d = torch.zeros_like(q1d)
        I1d[q1d > maxq] = torch.nan
        if q1d_effective.numel() == 0:
            return I1d  # nothing within the reciprocal grid's range

        # Use the first positive q as reference so that q = 0 does not make
        # the sample counts NaN. This is identical to q1d_effective[0]
        # whenever that value is positive.
        positive = q1d_effective[q1d_effective > 0]
        if positive.numel() > 0:
            N = torch.round(q1d_effective/positive[0]) + offset
        else:
            N = torch.zeros_like(q1d_effective) + offset

        qx, qy, qz = calcfunc.multiple_spherical_sampling(q1d_effective, N)

        Iall = self.intensity(qx, qy, qz)
        I = torch.zeros_like(q1d_effective)
        begin = 0
        # Samples of consecutive shells are laid out back to back.
        for i, n in enumerate(N.int().tolist()):
            I[i] = Iall[begin:begin+n].mean()
            begin += n

        I1d[in_range] = I
        return I1d
class AssemblyModel(Model):
    """Model whose amplitude is the coherent sum of its components' amplitudes.

    Components may be GridModel or nested AssemblyModel instances. Each
    component applies its own transforms.
    """

    def __init__(self, *grid_models: 'GridModel|AssemblyModel') -> None:
        self.components = grid_models

    @property
    def maxq(self) -> float:
        """Smallest maxq among the components, i.e. the range they all cover."""
        # A generator works for a single component; min(*[x]) raised TypeError.
        return min(model.maxq for model in self.components)

    @log
    def amplitude(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Sum of the component amplitudes at (qx, qy, qz)."""
        F = torch.complex(torch.zeros_like(qx), torch.zeros_like(qx))
        for model in self.components:
            # Out-of-place add lets the sum type-promote. An in-place += would
            # silently downcast component amplitudes to the dtype derived from qx.
            F = F + model.amplitude(qx, qy, qz)
        return F

    @log
    def intensity(self, qx: Tensor, qy: Tensor, qz: Tensor) -> Tensor:
        """Intensity |F|^2 of the summed amplitude at (qx, qy, qz)."""
        F = self.amplitude(qx, qy, qz)
        return F.real**2 + F.imag**2

    @log
    def intensity_ave(self, q1d: Tensor, offset: int = 100) -> Tensor:
        """Orientation-averaged intensity for each |q| in the 1D tensor ``q1d``.

        Each in-range shell is sampled by calcfunc.multiple_spherical_sampling
        with ``round(q / q_ref) + offset`` points, where ``q_ref`` is the first
        positive in-range q. The sampled intensities are then averaged per
        shell. Entries with q > maxq are NaN.
        """
        maxq = self.maxq
        in_range = q1d <= maxq
        q1d_effective = q1d[in_range]

        I1d = torch.zeros_like(q1d)
        I1d[q1d > maxq] = torch.nan
        if q1d_effective.numel() == 0:
            return I1d  # nothing within the common q range

        # Use the first positive q as reference so that q = 0 does not make
        # the sample counts NaN. This is identical to q1d_effective[0]
        # whenever that value is positive.
        positive = q1d_effective[q1d_effective > 0]
        if positive.numel() > 0:
            N = torch.round(q1d_effective/positive[0]) + offset
        else:
            N = torch.zeros_like(q1d_effective) + offset

        qx, qy, qz = calcfunc.multiple_spherical_sampling(q1d_effective, N)

        Iall = self.intensity(qx, qy, qz)
        I = torch.zeros_like(q1d_effective)
        begin = 0
        # Samples of consecutive shells are laid out back to back.
        for i, n in enumerate(N.int().tolist()):
            I[i] = Iall[begin:begin+n].mean()
            begin += n

        I1d[in_range] = I
        return I1d