Coverage for src\model2sas\utils.py: 33%
128 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-09 17:12 +0800
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-09 17:12 +0800
1"""some useful utility functions
2"""
4from typing import Literal
6import time
7import sys
8import functools
10import torch
11from torch import Tensor
12from loguru import logger
14from . import global_vars
# Current nesting depth of active calls wrapped by the `log` decorator.
# `log` increments it on entry and decrements it on exit; the depth sets how
# far a log line is indented and is compared against global_vars.LOG_LEVEL.
func_tier: int = 0
def add_logger(sink=sys.stderr, format=global_vars.LOG_FORMAT_STR, **kwargs):
    """Add a loguru sink that uses the project's log format by default.

    Args:
        sink: destination passed to ``logger.add`` (defaults to ``sys.stderr``).
        format: format string for this sink (defaults to
            ``global_vars.LOG_FORMAT_STR``).
        **kwargs: any further options, forwarded unchanged to ``logger.add``.

    Returns:
        The handler id produced by ``logger.add``. Pass it to
        ``logger.remove`` to detach the sink again.
    """
    return logger.add(sink, format=format, **kwargs)
# Replace loguru's default stderr handler (id 0) with one that uses the
# project's log format.
try:
    logger.remove(0)
except ValueError:
    # Handler 0 is already gone (e.g. the host application removed it, or this
    # module is being reloaded); the goal of dropping it is already met.
    pass
logger.add(sys.stderr, format=global_vars.LOG_FORMAT_STR)
def log(func):
    """Decorator that logs when ``func`` starts and finishes, with its run time.

    Behaviour is decided from ``global_vars`` at call time:

    * if ``global_vars.PRINT_LOG`` is falsy, ``func`` is simply called and
      nothing is logged;
    * otherwise the module-level ``func_tier`` counter is incremented for the
      duration of the call. Nested decorated calls are therefore indented by
      their depth. An entry (``logger.info``) and an exit (``logger.success``,
      with the elapsed ``time.perf_counter`` seconds) message are emitted when
      ``global_vars.LOG_LEVEL`` is negative or the current depth does not
      exceed it.

    The depth counter is restored even when ``func`` raises. Otherwise one
    failing call would shift the indentation and level filtering of every
    later log line.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global func_tier
        if not global_vars.PRINT_LOG:
            return func(*args, **kwargs)

        func_tier += 1
        try:
            if global_vars.LOG_LEVEL < 0 or func_tier <= global_vars.LOG_LEVEL:
                logger.info(f'[{" ":>11}] {"| "*func_tier}○ {func.__name__}')
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            time_cost = time.perf_counter() - start_time
            # LOG_LEVEL is read from global_vars again after the call, exactly
            # like the entry check.
            if global_vars.LOG_LEVEL < 0 or func_tier <= global_vars.LOG_LEVEL:
                logger.success(f'[{time_cost:>9.6f} s] {"| "*func_tier}● {func.__name__}')
        finally:
            func_tier -= 1
        return result
    return wrapper
class Detector:
    '''Simulation of a flat 2d detector.

    Works in a coordinate system with the sample position as the origin and
    the beam direction as the positive Y axis. Pixel positions are stored as
    three tensors ``x``, ``y``, ``z``, each of shape ``resolution``.
    All lengths should be in meters except the wavelength.
    Output q is in the reciprocal of the wavelength unit.
    '''
    def __init__(self, resolution: tuple[int, int], pixel_size: float) -> None:
        """Create a detector whose pixel grid is centred on the beam axis at y = 0.

        Args:
            resolution: number of pixels along x and along z.
            pixel_size: edge length of one pixel (the same step is used for
                x and z, so pixels are square).
        """
        x = torch.arange(resolution[0], dtype=torch.float32)
        z = torch.arange(resolution[1], dtype=torch.float32)
        x, z = pixel_size * x, pixel_size * z
        # shift so the middle of the pixel grid sits at x = z = 0 (on the beam)
        cx, cz = (x[0] + x[-1]) / 2, (z[0] + z[-1]) / 2
        x, z = x - cx, z - cz
        x, z = torch.meshgrid(x, z, indexing='ij')
        y = torch.zeros_like(x, dtype=torch.float32)
        # pixel coordinates, each of shape `resolution`
        self.x, self.y, self.z = x, y, z
        # detector-attached rotation axes; `_rotate` updates the other two
        # axes whenever the detector rotates about one of them
        self.pitch_axis = torch.tensor((1, 0, 0), dtype=torch.float32)
        self.yaw_axis = torch.tensor((0, 0, 1), dtype=torch.float32)
        self.roll_axis = torch.tensor((0, 1, 0), dtype=torch.float32)
        # sample-detector distance, tracked through set_sdd() and translate()
        self.sdd = 0.
        self.resolution = resolution
        self.pixel_size = pixel_size

    def get_center(self) -> Tensor:
        """Return the detector centre ``(x, y, z)`` as a tensor of shape (3,).

        The centre is the midpoint of the two opposite corner pixels
        ``[0, 0]`` and ``[-1, -1]``. That midpoint is preserved by the
        translations and rotations applied in this class. ``torch.stack``
        keeps the coordinates' dtype and device.
        """
        return torch.stack(
            [(coord[0, 0] + coord[-1, -1]) / 2 for coord in (self.x, self.y, self.z)]
        )

    def set_sdd(self, sdd: float) -> None:
        """Set the sample-detector distance by shifting every pixel along y."""
        delta_sdd = sdd - self.sdd
        self.y = self.y + delta_sdd
        self.sdd = sdd

    def translate(self, vx: float, vz: float, vy: float = 0.) -> None:
        """Shift the detector by (vx, vy, vz).

        Note the argument order: x, z, then the optional y. A shift along y
        also changes the tracked ``sdd`` by ``vy``.
        """
        self.x = self.x + vx
        self.z = self.z + vz
        self.y = self.y + vy
        self.sdd = self.sdd + vy

    def _euler_rodrigues_rotate(self, coord: Tensor, axis: Tensor, angle: float) -> Tensor:
        '''Rotate coordinates by the Euler-Rodrigues formula.

        x' = x + 2a (w x x) + 2 w x (w x x), where a = cos(angle/2) and
        w = unit(axis) * sin(angle/2). The rotation is right-handed about
        ``axis`` by ``angle`` radians.

        Args:
            coord: tensor whose last dimension holds (x, y, z); used here
                with shape (n, 3) for pixels and (3,) for single axes.
            axis: rotation axis of shape (3,); need not be normalised.
            angle: rotation angle in radians (a Python number or 0-d tensor).

        Returns:
            Rotated coordinates with the same shape as ``coord``.
        '''
        ax = axis / torch.sqrt(torch.sum(axis**2))
        # as_tensor accepts tensor angles without the copy-construct warning
        # that torch.tensor() emits, and matches the axis dtype/device
        ang = torch.as_tensor(angle, dtype=ax.dtype, device=ax.device)
        a = torch.cos(ang / 2)
        w = ax * torch.sin(ang / 2)

        x = coord
        # -(x cross w) == w cross x
        wx = -torch.linalg.cross(x, w.expand_as(x), dim=-1)
        x_rotated = x + 2*a*wx + 2*(-torch.linalg.cross(wx, w.expand_as(wx), dim=-1))
        return x_rotated

    def _rotate(self, rotation_type: Literal['pitch', 'yaw', 'roll'], angle: float) -> None:
        """Rotate the detector about one of its own axes through its centre.

        The two other detector axes are rotated along with the pixels, so
        later rotations use the updated orientation.

        Raises:
            ValueError: if ``rotation_type`` is not 'pitch', 'yaw' or 'roll'.
        """
        if rotation_type == 'pitch':
            axis = self.pitch_axis
            self.yaw_axis = self._euler_rodrigues_rotate(self.yaw_axis, axis, angle)
            self.roll_axis = self._euler_rodrigues_rotate(self.roll_axis, axis, angle)
        elif rotation_type == 'yaw':
            axis = self.yaw_axis
            self.pitch_axis = self._euler_rodrigues_rotate(self.pitch_axis, axis, angle)
            self.roll_axis = self._euler_rodrigues_rotate(self.roll_axis, axis, angle)
        elif rotation_type == 'roll':
            axis = self.roll_axis
            self.pitch_axis = self._euler_rodrigues_rotate(self.pitch_axis, axis, angle)
            self.yaw_axis = self._euler_rodrigues_rotate(self.yaw_axis, axis, angle)
        else:
            raise ValueError('Unsupported rotation type: {}'.format(rotation_type))
        center = self.get_center()

        # rotate all pixels as an (n, 3) point cloud about the detector centre
        x1d, y1d, z1d = self.x.flatten(), self.y.flatten(), self.z.flatten()
        coord = torch.stack((x1d, y1d, z1d), dim=1)
        coord = coord - center
        rotated_coord = self._euler_rodrigues_rotate(coord, axis, angle)
        rotated_coord = rotated_coord + center
        x, y, z = torch.unbind(rotated_coord, dim=-1)
        self.x, self.y, self.z = (
            x.reshape(self.resolution),
            y.reshape(self.resolution),
            z.reshape(self.resolution),
        )

    def pitch(self, angle: float) -> None:
        """Rotate the detector by ``angle`` radians about its pitch axis."""
        self._rotate('pitch', angle)

    def yaw(self, angle: float) -> None:
        """Rotate the detector by ``angle`` radians about its yaw axis."""
        self._rotate('yaw', angle)

    def roll(self, angle: float) -> None:
        """Rotate the detector by ``angle`` radians about its roll axis."""
        self._rotate('roll', angle)

    def _real_coord_to_reciprocal_coord(self, x: Tensor, y: Tensor, z: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        '''Convert real-space points to reciprocal coordinates divided by k.

        In a coordinate system where the sample position is the origin and
        the beam direction is the positive Y axis, return
        unit(x, y, z) - (0, 1, 0), i.e. the scattering vector without the
        wave-vector factor k = 2pi/wavelength. A point exactly at the origin
        has zero length and yields NaN.
        '''
        mod = torch.sqrt(x**2 + y**2 + z**2)
        unit_vector_ks_x, unit_vector_ks_y, unit_vector_ks_z = x/mod, y/mod, z/mod
        unit_vector_ki_x, unit_vector_ki_y, unit_vector_ki_z = 0., 1., 0.
        rx = unit_vector_ks_x - unit_vector_ki_x
        ry = unit_vector_ks_y - unit_vector_ki_y
        rz = unit_vector_ks_z - unit_vector_ki_z
        return rx, ry, rz

    def get_reciprocal_coord(self, wavelength: float) -> tuple[Tensor, Tensor, Tensor]:
        """Return (qx, qy, qz) for every pixel, each of shape ``resolution``.

        q = k * (unit(pixel) - unit(beam)) with k = 2*pi / wavelength.
        """
        k = 2*torch.pi / wavelength
        rx, ry, rz = self._real_coord_to_reciprocal_coord(self.x, self.y, self.z)
        qx, qy, qz = k*rx, k*ry, k*rz
        return qx, qy, qz

    def get_q_range(self, wavelength: float) -> tuple[float, float]:
        """Return the smallest and largest |q| over all pixels as Python floats."""
        qx, qy, qz = self.get_reciprocal_coord(wavelength)
        q = torch.sqrt(qx**2 + qy**2 + qz**2)
        return q.min().item(), q.max().item()

    def get_beamstop_mask(self, d: float) -> Tensor:
        '''Return a mask for a circular beamstop centred on the beam axis.

        Args:
            d: beamstop diameter, in the same length unit as the pixel
                coordinates.

        Returns:
            float32 tensor of shape ``resolution``: 0 where a pixel's
            distance from the beam (Y) axis, sqrt(x**2 + z**2), is <= d/2,
            and 1 elsewhere. Multiply element-wise with a pattern of the same
            shape to blank the shadowed pixels.
        '''
        mask = torch.ones(self.resolution, dtype=torch.float32)
        mask[(self.x**2 + self.z**2) <= (d/2)**2] = 0.
        return mask
def save_pdb(filename: str, x: Tensor, y: Tensor, z: Tensor, sld: Tensor, atom_name: str = 'CA', temperature_factor: float = 0.0, element_symbol: str = 'C') -> None:
    """Convert a lattice model to a pdb file
    for calculation by other software like CRYSOL.

    Only points with sld != 0 are written. The sld value itself is not
    stored: every kept point becomes one ATOM record with occupancy 1.0, so
    the file describes a model of uniform density.

    Args:
        filename (str): output pdb file name
        x (Tensor): x coordinates
        y (Tensor): y coordinates
        z (Tensor): z coordinates
        sld (Tensor): sld values at each coordinate; flattened, it must have
            as many elements as x, y and z
        atom_name (str): atom name (columns 13-16 of each ATOM record)
        temperature_factor (float): temperature factor written for every atom
        element_symbol (str): element symbol (columns 77-78)
    """
    keep = sld.flatten() != 0.
    # one conversion per axis instead of formatting 0-d tensors one at a time
    xs = x.flatten()[keep].tolist()
    ys = y.flatten()[keep].tolist()
    zs = z.flatten()[keep].tolist()

    lines = ['REMARK 265 EXPERIMENT TYPE: THEORETICAL MODELLING\n']
    for i, (xi, yi, zi) in enumerate(zip(xs, ys, zs)):
        # The serial number field is only 5 columns wide. Wrap it so models
        # with more than 99999 points keep every later field in its column.
        serial = i % 99999 + 1
        # Fixed-column PDB ATOM record (80 columns):
        # 1-6 record, 7-11 serial, 13-16 name, 18-20 resName, 22 chain,
        # 23-26 resSeq, 31-38/39-46/47-54 x/y/z, 55-60 occupancy,
        # 61-66 tempFactor, 77-78 element.
        # NOTE(review): by PDB convention a one-letter-element atom name such
        # as CA usually starts at column 14 (' CA '); it is left-justified at
        # column 13 here, and the element field is what disambiguates it.
        lines.append(
            f'{"ATOM":<6}{serial:>5d} {atom_name:<4} {"ASP":>3} {"A":>1}{1:>4d}    '
            f'{xi:>8.3f}{yi:>8.3f}{zi:>8.3f}{1.0:>6.2f}{temperature_factor:>6.2f}'
            f'          {element_symbol:>2}  \n'
        )
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(lines)