Coverage for src\model2sas\utils.py: 33%

128 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-09 17:12 +0800

1"""some useful utility functions 

2""" 

3 

4from typing import Literal 

5 

6import time 

7import sys 

8import functools 

9 

10import torch 

11from torch import Tensor 

12from loguru import logger 

13 

14from . import global_vars 

15 

# Current nesting depth of calls wrapped by the ``log`` decorator below: its
# wrapper increments this on entry and decrements it on exit, and uses the
# value both for indentation of the log line and for the LOG_LEVEL depth cut.
func_tier: int = 0

17 

def add_logger(sink=sys.stderr, format=global_vars.LOG_FORMAT_STR, **kwargs) -> int:
    """Register an additional loguru sink.

    Args:
        sink: destination handed to ``logger.add``. The default is the
            ``sys.stderr`` object that existed when this module was imported.
        format: loguru format string, ``global_vars.LOG_FORMAT_STR`` by default.
        **kwargs: forwarded unchanged to ``logger.add``.

    Returns:
        int: the handler id returned by ``logger.add``; pass it to
        ``logger.remove`` to detach this sink again.
    """
    return logger.add(sink, format=format, **kwargs)

20 

# Replace loguru's pre-configured default handler (id 0) with a stderr sink
# that uses this package's log format.
try:
    logger.remove(0)
except ValueError:
    # Handler 0 no longer exists (e.g. the application already called
    # logger.remove() before importing this module); nothing to replace.
    pass
logger.add(sys.stderr, format=global_vars.LOG_FORMAT_STR)

23 

def log(func):
    """Decorator that logs entry, exit and elapsed time of ``func``.

    When ``global_vars.PRINT_LOG`` is falsy the wrapped function is simply
    called. Otherwise the module-level ``func_tier`` nesting counter is
    incremented for the duration of the call, and an entry line plus a
    completion line (with the ``time.perf_counter`` duration) are emitted as
    long as ``global_vars.LOG_LEVEL`` is negative or the current depth does
    not exceed it.

    The counter is restored even when ``func`` raises, so one failing call
    does not shift the indentation / depth filtering of every later call.
    No completion line is logged for a call that raises.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global func_tier
        if not global_vars.PRINT_LOG:
            return func(*args, **kwargs)

        func_tier += 1
        try:
            if global_vars.LOG_LEVEL < 0 or func_tier <= global_vars.LOG_LEVEL:
                logger.info(f'[{" ":>11}] {"| "*func_tier}○ {func.__name__}')
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            time_cost = time.perf_counter() - start_time
            # The depth filter is evaluated again on purpose: LOG_LEVEL may
            # have been changed while func was running.
            if global_vars.LOG_LEVEL < 0 or func_tier <= global_vars.LOG_LEVEL:
                logger.success(f'[{time_cost:>9.6f} s] {"| "*func_tier}● {func.__name__}')
            return result
        finally:
            # Always undo the increment, including on exceptions.
            func_tier -= 1
    return wrapper

42 

43 

class Detector:
    '''Simulation of a flat 2d detector.

    The coordinate system has the sample position as origin and the beam
    direction as the positive Y axis. A new detector lies in the XZ plane
    at y = 0 with its grid centred on x = z = 0.
    All lengths should be in meters except wavelength.
    Output q has the unit of reciprocal wavelength.
    '''
    def __init__(self, resolution: tuple[int, int], pixel_size: float) -> None:
        '''Build the pixel coordinate grids.

        Args:
            resolution: number of pixels along x and along z.
            pixel_size: spacing between neighbouring pixels (same for x and z).
        '''
        x = pixel_size * torch.arange(resolution[0], dtype=torch.float32)
        z = pixel_size * torch.arange(resolution[1], dtype=torch.float32)
        # Shift both axes so that the grid is centred on x = z = 0.
        x = x - (x[0] + x[-1]) / 2
        z = z - (z[0] + z[-1]) / 2
        x, z = torch.meshgrid(x, z, indexing='ij')
        y = torch.zeros_like(x, dtype=torch.float32)
        # Pixel coordinates, each of shape `resolution`.
        self.x, self.y, self.z = x, y, z
        # Rotation axes attached to the detector; _rotate() turns the two
        # axes that are not the rotation axis together with the pixels.
        self.pitch_axis = torch.tensor((1,0,0), dtype=torch.float32)
        self.yaw_axis = torch.tensor((0,0,1), dtype=torch.float32)
        self.roll_axis = torch.tensor((0,1,0), dtype=torch.float32)
        # Sample-to-detector distance tracked by set_sdd() / translate().
        self.sdd = 0.
        self.resolution = resolution
        self.pixel_size = pixel_size

    def get_center(self) -> Tensor:
        '''Return the detector centre as a tensor (cx, cy, cz).

        The centre is the midpoint between the two opposite corner pixels
        [0, 0] and [-1, -1].
        '''
        corners_mid = [
            (grid[0, 0] + grid[-1, -1]) / 2
            for grid in (self.x, self.y, self.z)
        ]
        # stack keeps dtype/device of the grids instead of rebuilding a new
        # tensor from python scalars.
        return torch.stack(corners_mid)

    def set_sdd(self, sdd: float) -> None:
        '''Move the detector along Y so that the tracked distance equals sdd.'''
        self.y = self.y + (sdd - self.sdd)
        self.sdd = sdd

    def translate(self, vx: float, vz: float, vy: float = 0.) -> None:
        '''Shift all pixels by (vx, vy, vz); vy is also added to self.sdd.'''
        self.x = self.x + vx
        self.z = self.z + vz
        self.y = self.y + vy
        self.sdd = self.sdd + vy

    def _euler_rodrigues_rotate(self, coord: Tensor, axis: Tensor, angle: float) -> Tensor:
        '''Rotate coordinates by the Euler-Rodrigues rotation formula.

        Args:
            coord: vectors to rotate, last dimension of size 3
                (shape (n,3), or (3,) for a single vector).
            axis: rotation axis, shape (3,); it is normalised here.
            angle: rotation angle in radians (python number or 0-d tensor).

        Returns:
            Rotated vectors with the same shape as coord.
        '''
        unit_axis = axis / torch.sqrt(torch.sum(axis**2))
        # as_tensor accepts python numbers and tensors alike and puts the
        # angle on the dtype/device of the coordinates.
        half_angle = torch.as_tensor(angle, dtype=coord.dtype, device=coord.device) / 2
        a = torch.cos(half_angle)
        w = unit_axis * torch.sin(half_angle)

        # x' = x + 2a (w × x) + 2 (w × (w × x))
        w_cross_x = torch.linalg.cross(w.expand_as(coord), coord, dim=-1)
        w_cross_wx = torch.linalg.cross(w.expand_as(w_cross_x), w_cross_x, dim=-1)
        return coord + 2*a*w_cross_x + 2*w_cross_wx

    def _rotate(self, rotation_type: Literal['pitch', 'yaw', 'roll'], angle: float) -> None:
        '''Rotate the detector about one of its own axes through its centre.

        The two axes other than the rotation axis are rotated as well, so
        later rotations are expressed in the detector's current orientation.

        Raises:
            ValueError: if rotation_type is not 'pitch', 'yaw' or 'roll'.
        '''
        axes = {
            'pitch': self.pitch_axis,
            'yaw': self.yaw_axis,
            'roll': self.roll_axis,
        }
        if rotation_type not in axes:
            raise ValueError('Unsupported rotation type: {}'.format(rotation_type))
        axis = axes[rotation_type]
        for name, other_axis in axes.items():
            if name != rotation_type:
                setattr(
                    self, f'{name}_axis',
                    self._euler_rodrigues_rotate(other_axis, axis, angle),
                )

        # Rotate the pixels about the detector centre, not about the origin.
        center = self.get_center()
        coord = torch.stack(
            (self.x.flatten(), self.y.flatten(), self.z.flatten()), dim=1
        )
        rotated_coord = self._euler_rodrigues_rotate(coord - center, axis, angle) + center
        x, y, z = torch.unbind(rotated_coord, dim=-1)
        self.x = x.reshape(self.resolution)
        self.y = y.reshape(self.resolution)
        self.z = z.reshape(self.resolution)

    def pitch(self, angle: float) -> None:
        '''Rotate by angle (radians) about the pitch axis (initially +X).'''
        self._rotate('pitch', angle)

    def yaw(self, angle: float) -> None:
        '''Rotate by angle (radians) about the yaw axis (initially +Z).'''
        self._rotate('yaw', angle)

    def roll(self, angle: float) -> None:
        '''Rotate by angle (radians) about the roll axis (initially +Y).'''
        self._rotate('roll', angle)

    def _real_coord_to_reciprocal_coord(self, x: Tensor, y: Tensor, z: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        '''In a coordinate system where sample position as origin,
        beam direction as positive Y axis, calculate the corresponding
        reciprocal coordinates (without multiplying by the wave vector
        k=2pi/wavelength) of the real-space coordinates (x,y,z).

        The result is the unit vector towards (x,y,z) minus the unit
        vector of the incident beam (0,1,0). A point exactly at the origin
        has no direction and yields NaN.
        '''
        mod = torch.sqrt(x**2 + y**2 + z**2)
        # scattered direction minus incident direction (0, 1, 0)
        rx = x/mod - 0.
        ry = y/mod - 1.
        rz = z/mod - 0.
        return rx, ry, rz

    def get_reciprocal_coord(self, wavelength: float) -> tuple[Tensor, Tensor, Tensor]:
        '''Return (qx, qy, qz) of every pixel for the given wavelength.'''
        k = 2*torch.pi / wavelength
        rx, ry, rz = self._real_coord_to_reciprocal_coord(self.x, self.y, self.z)
        return k*rx, k*ry, k*rz

    def get_q_range(self, wavelength: float) -> tuple[float, float]:
        '''Return (min, max) of |q| over all pixels for the given wavelength.'''
        qx, qy, qz = self.get_reciprocal_coord(wavelength)
        q = torch.sqrt(qx**2 + qy**2 + qz**2)
        return q.min().item(), q.max().item()

    def get_beamstop_mask(self, d: float) -> Tensor:
        '''Return a float32 mask of shape self.resolution for a round beamstop.

        Pixels with x**2 + z**2 <= (d/2)**2, i.e. within a circle of
        diameter d around the beam axis x = z = 0, are 0; all others are 1.
        '''
        mask = torch.ones(self.resolution, dtype=torch.float32)
        mask[(self.x**2+self.z**2) <= (d/2)**2] = 0.
        return mask

165 

166 

def save_pdb(filename: str, x: Tensor, y: Tensor, z:Tensor, sld: Tensor, atom_name: str = 'CA', temperature_factor: float = 0.0, element_symbol: str = 'C') -> None:
    """Convert a lattice model to a pdb file
    for calculation by other software like CRYSOL.
    Only points with sld != 0 are kept, and every kept point is written
    as the same kind of atom (the sld value itself is not stored).

    Args:
        filename (str): output pdb file name
        x (Tensor): x coordinates
        y (Tensor): y coordinates
        z (Tensor): z coordinates
        sld (Tensor): sld values at each coordinates; must have the same
            number of elements as x, y and z
        atom_name (str): atom name written to each ATOM record
        temperature_factor (float): temperature factor written to each record
        element_symbol (str): element symbol written to each record
    """
    # Build the selection mask once and convert to python floats up front
    # instead of formatting 0-d tensors one by one.
    occupied = sld.flatten() != 0
    xs = x.flatten()[occupied].tolist()
    ys = y.flatten()[occupied].tolist()
    zs = z.flatten()[occupied].tolist()

    lines = ['REMARK 265 EXPERIMENT TYPE: THEORETICAL MODELLING\n']
    # PDB ATOM records are fixed-column: the padding below places x/y/z in
    # columns 31-54, occupancy in 55-60, temperature factor in 61-66 and the
    # element symbol in 77-78 (columns 79-80, the charge, are left blank).
    lines.extend(
        f'{"ATOM":<6}{serial:>5d} {atom_name:<4} {"ASP":>3} {"A":>1}{1:>4d}'
        f'    {xi:>8.3f}{yi:>8.3f}{zi:>8.3f}'
        f'{1.0:>6.2f}{temperature_factor:>6.2f}'
        f'          {element_symbol:>2}  \n'
        for serial, (xi, yi, zi) in enumerate(zip(xs, ys, zs), start=1)
    )
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(lines)