Coverage for tests\test_module.py: 98%

127 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-09 17:12 +0800

1import copy 

2import math 

3import torch 

4from torch import Tensor 

5 

6import model2sas 

7import model2sas.calcfunc 

8 

# Device on which the test models are built. Fall back to the CPU when no CUDA
# device is present so the suite can still run on CPU-only machines (e.g. CI).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

10 

def logscale_close(
    t1: Tensor,
    t2: Tensor,
    rtol: float = 0.01,
    atol: float = 0.0,
    nsigma: float = 3,
    max_outlier_frac: float = 0.1,
) -> bool:
    """Check whether two curves coincide up to a constant factor on a log scale.

    The element-wise offset ``log10(t1) - log10(t2)`` is computed. Entries that
    are not finite (NaN inputs, ``log10`` of zero or of a negative value) are
    ignored. Entries deviating from the mean by more than ``nsigma`` sample
    standard deviations are then discarded repeatedly until none remain. The
    curves are judged close when the standard deviation of the surviving
    offsets is at most ``atol + rtol * |mean offset|``.

    Args:
        t1, t2: Tensors with broadcast-compatible shapes, compared element-wise.
        rtol: Allowed spread of the log offset relative to its mean magnitude.
        atol: Allowed absolute spread of the log offset, in decades. The
            default of 0 keeps the purely relative criterion.
        nsigma: Outlier threshold, in sample standard deviations.
        max_outlier_frac: Maximum fraction of all elements that may be
            discarded as outliers.

    Returns:
        ``True`` if the curves are close. ``False`` otherwise, including when
        fewer than two comparable points remain.

    Raises:
        RuntimeError: If more than ``max_outlier_frac`` of the elements had to
            be discarded as outliers.
    """
    def num_valid(t: Tensor) -> int:
        # Number of entries that are not NaN.
        return int((~t.isnan()).sum().item())

    def ave(t: Tensor) -> Tensor:
        return t.nanmean()

    def sigma(t: Tensor) -> Tensor:
        # Sample (Bessel-corrected) standard deviation over the non-NaN entries.
        return torch.sqrt(((t - ave(t)) ** 2).nansum() / (num_valid(t) - 1))

    def is_outlier(t: Tensor) -> Tensor:
        return (t - ave(t)).abs() > nsigma * sigma(t)

    def del_outliers(t: Tensor) -> Tensor:
        # Replace outliers with NaN, then re-estimate mean and sigma from the
        # remaining points, until a pass finds no new outliers.
        n_removed = 0
        outlier = is_outlier(t)
        while outlier.any():
            n_removed += int(outlier.sum().item())
            if n_removed > max_outlier_frac * t.numel():
                raise RuntimeError('Too many outliers')
            t = torch.where(outlier, torch.nan, t)
            outlier = is_outlier(t)
        return t

    logoffset = torch.log10(t1) - torch.log10(t2)
    # log10(0) = -inf and log10(<0) = nan. Drop both so they neither poison
    # the mean/sigma nor get counted as outliers.
    logoffset = torch.where(~torch.isfinite(logoffset), torch.nan, logoffset)
    logoffset = del_outliers(logoffset)
    if num_valid(logoffset) < 2:
        return False
    spread = sigma(logoffset)
    # abs(): a negative mean offset (t1 below t2) must not turn the criterion
    # into a negative ratio that passes trivially. '<=' lets exactly parallel
    # curves (spread == 0, including t1 == t2) count as close.
    center = ave(logoffset).abs()
    return bool((spread <= atol + rtol * center).item())

32 

33 

def F1d_sph(q1d: Tensor, R: float, sld: float = 1):
    """Scattering amplitude of a homogeneous sphere.

    ``F(q) = V * sld * 3 * (sin(u) - u*cos(u)) / u**3`` with ``u = q*R`` and
    ``V = 4/3 * pi * R**3``.

    Args:
        q1d: Scattering-vector magnitudes.
        R: Sphere radius.
        sld: Scattering length density of the sphere.

    Returns:
        Tensor of amplitudes, same shape as ``q1d``. At ``q*R == 0`` the
        analytic limit ``V * sld`` is returned instead of the 0/0 NaN that the
        raw formula would produce.
    """
    volume_sld = (4*torch.pi*R**3)/3 * sld
    u = q1d*R
    is_zero = u == 0
    # Substitute a harmless value at u == 0 so no NaN is produced there; those
    # entries are overwritten with the limit below. Elsewhere safe_u == u, so
    # results are bit-identical to the plain formula.
    safe_u = torch.where(is_zero, torch.ones_like(u), u)
    amplitude = volume_sld * (3/safe_u**3) * (torch.sin(safe_u) - safe_u*torch.cos(safe_u))
    return torch.where(is_zero, torch.full_like(u, volume_sld), amplitude)

37 

def I1d_sphere(q1d: Tensor, R: float, sld: float = 1) -> Tensor:
    """Analytical scattering intensity of a homogeneous sphere.

    The intensity is the squared amplitude returned by ``F1d_sph`` for the
    same ``q1d``, radius ``R`` and scattering length density ``sld``.
    """
    amplitude = F1d_sph(q1d, R, sld)
    return amplitude ** 2

40 

def I1d_core_shell_sphere(q1d: Tensor, R_core: float, thickness: float, sld_core: float, sld_shell: float) -> Tensor:
    """Analytical scattering intensity of a core-shell sphere.

    The amplitude is built by superposition: a core sphere with ``sld_core``,
    plus a full outer sphere (radius ``R_core + thickness``) with ``sld_shell``,
    minus the part of that outer sphere occupied by the core (radius
    ``R_core``, ``sld_shell``). The intensity is the square of this amplitude.
    """
    core = F1d_sph(q1d, R_core, sld_core)
    outer = F1d_sph(q1d, R_core+thickness, sld_shell)
    hole = F1d_sph(q1d, R_core, sld_shell)
    amplitude = core + outer - hole
    return amplitude ** 2

43 

44 

def core_shell_sphere_model(R_core, thickness, sld_core, sld_shell, device=DEVICE) -> model2sas.GridModel:
    """Build a grid model of a core-shell sphere centred at the origin and scatter it.

    The bounding box is the cube ``[-(R_core+thickness), R_core+thickness]^3``.
    Grid points with ``r <= R_core`` receive ``sld_core``. Points with
    ``R_core <= r <= R_core+thickness`` then receive ``sld_shell``, so points
    exactly at ``r == R_core`` end up with ``sld_shell`` because that
    assignment runs second. All other points keep SLD 0.

    Returns:
        The ``GridModel`` after ``scatter()`` has been called on it.
    """
    R_outer = R_core + thickness
    lower, upper = [-R_outer] * 3, [R_outer] * 3
    bbox = model2sas.model.BoundingBox(*lower, *upper)
    x1d, y1d, z1d, x, y, z = model2sas.readfile.meshgrid(bbox, device=device)
    sld = torch.zeros_like(x)
    # r is the first output of the cartesian -> spherical conversion
    # (presumably the radial distance from the origin).
    r, _theta, _phi = model2sas.calcfunc.convert_coord(x, y, z, 'car', 'sph')
    in_core = r <= R_core
    in_shell = (r >= R_core) & (r <= R_outer)
    sld[in_core] = sld_core
    sld[in_shell] = sld_shell
    grid_model = model2sas.GridModel(x1d, y1d, z1d, sld)
    grid_model.scatter()
    return grid_model

56 

57 

def test_gridmodel():
    """Compare a grid-built core-shell sphere with the analytical intensity.

    The model is translated and rotated before ``intensity_ave`` is evaluated.
    The analytical curve does not depend on those transforms, so the test
    presumes ``intensity_ave`` is insensitive to them. Both curves are plotted
    and must agree up to a constant factor (``logscale_close``).
    """
    # core-shell geometry and scattering length densities
    R_core, thickness = 10, 5
    sld_core, sld_shell = -2, 1
    model = core_shell_sphere_model(R_core, thickness, sld_core, sld_shell)
    model.translate(32, 57, 101)
    model.rotate((-2.1, 1.5, 10.4), 0.41 * torch.pi)

    q = torch.linspace(0.01, 2, steps=1000)
    I_model = model.intensity_ave(q)
    I_theory = I1d_core_shell_sphere(q, R_core, thickness, sld_core, sld_shell)

    fig = model2sas.Figure()
    for curve, label in ((I_theory, 'theoretical'), (I_model, 'model2sas')):
        fig.plot_curve1d(q, curve, name=label)
    fig.show()

    assert logscale_close(I_model, I_theory), 'Model2SAS calculated inconsistent with theoretical result'

77 

78 

def test_assemblymodel_and_transform():
    """Compare a joined-cylinder model with two transformed half-cylinders.

    Two cylinders of height ``H/2`` (one per SLD value) are moved and rotated,
    presumably into the configuration described by ``joined_cylinder.py``, and
    combined into an ``AssemblyModel``. The averaged intensities of the single
    model and of the assembly must agree up to a constant factor
    (``logscale_close``).
    """
    R, H = 10, 30
    sld_value1, sld_value2 = -1, 1.5
    half_height = H / 2
    d = 5 * H
    sqrt2 = math.sqrt(2)

    joined = model2sas.read_math('./tests/modelfiles/joined_cylinder.py', R=R, H=H, sld_value1=sld_value1, sld_value2=sld_value2, device=DEVICE)
    joined.scatter()

    def half_cylinder(sld_value):
        # Load one half-height cylinder and compute its scattering.
        cyl = model2sas.read_math('./tests/modelfiles/cylinder.py', R=R, H=half_height, sld_value=sld_value, device=DEVICE)
        cyl.scatter()
        return cyl

    # The transform calls mutate the models, so their order matters.
    lower = half_cylinder(sld_value1)
    lower.translate(0, 0, -d)
    lower.rotate((0, 3, 0), torch.pi / 4)

    upper = half_cylinder(sld_value2)
    upper.rotate((0, -0.7, 0), -torch.pi / 4)
    upper.translate(-d / sqrt2, 0, -d / sqrt2)
    upper.translate(half_height / sqrt2, 0, half_height / sqrt2)

    assembly = model2sas.AssemblyModel(lower, upper)

    q = torch.linspace(0.001, 2, 500)
    I_single = joined.intensity_ave(q)
    I_assembled = assembly.intensity_ave(q)

    fig = model2sas.Figure()
    fig.plot_curve1d(q, I_single, name='single')
    fig.plot_curve1d(q, I_assembled, name='2 assembled')
    fig.show()

    assert logscale_close(I_single, I_assembled)

111 

112 

113 

def test_readfile():
    """Smoke test for the model2sas file readers.

    Each reader's model (STL, PDB with the default and the neutron probe, and
    a math model given by path and by instance) is scattered, translated,
    rotated and orientation-averaged. The test passes if nothing raises.
    """
    stl_path = './tests/modelfiles/torus.stl'
    pdb_path = './tests/modelfiles/3v03.pdb'
    math_path = './tests/modelfiles/core_shell_sphere.py'
    math_params = dict(R_core=20, thickness=30, sld_core=2, sld_shell=-1)

    def run_pipeline(model):
        # Same processing sequence for every loaded model.
        model.scatter()
        model.translate(32, 57, 101)
        model.rotate((-2.1, 1.5, 10.4), 0.41 * torch.pi)
        model.intensity_ave(torch.linspace(0.01, 2, steps=1000))

    def from_mathmodel_instance():
        # Load the math model through an explicitly instantiated class.
        MathModelClass = model2sas.readfile.import_mathmodel_class(math_path)
        instance = MathModelClass()
        return model2sas.readfile.read_math(instance, **math_params, device=DEVICE)

    loaders = (
        lambda: model2sas.readfile.read_stl(stl_path, centering=True, device=DEVICE),
        lambda: model2sas.readfile.read_pdb(pdb_path, device=DEVICE),
        lambda: model2sas.readfile.read_pdb(pdb_path, probe='neutron', device=DEVICE),
        lambda: model2sas.readfile.read_math(math_path, **math_params, device=DEVICE),
        from_mathmodel_instance,
    )
    for load in loaders:
        run_pipeline(load())

138 

139 

def test_plot():
    """Smoke test for the ``model2sas.Figure`` plotting helpers.

    Renders a core-shell sphere's SLD as a volume and as voxels, its averaged
    1D intensity curve, its 3D intensity on a cubic q grid, and one 2D slice of
    that grid. No assertions: the test passes if nothing raises.
    """
    R_core, thickness, sld_core, sld_shell = 10, 5, -2, 1
    model = core_shell_sphere_model(R_core, thickness, sld_core, sld_shell)
    x, y, z = model.real_grid.coord3d
    sld = model.sld
    occupied = sld != 0

    # real-space SLD distribution
    fig = model2sas.Figure()
    fig.plot_volume3d(x, y, z, sld)
    fig.show()

    # only the grid points with non-zero SLD, drawn as voxels
    fig = model2sas.Figure()
    fig.plot_voxel3d(x[occupied], y[occupied], z[occupied], model.real_grid.spacing)
    fig.show()

    # averaged 1D intensity curve
    fig = model2sas.Figure()
    q = torch.linspace(0.01, 2, steps=1000)
    fig.plot_curve1d(q, model.intensity_ave(q))
    fig.show()

    # 3D intensity on a 50^3 grid spanning [-1, 1] along each q axis
    fig = model2sas.Figure()
    axis = torch.linspace(-1, 1, 50)
    qx, qy, qz = torch.meshgrid(axis, axis, axis, indexing='ij')
    I3d = model.intensity(qx, qy, qz)
    fig.plot_volume3d(qx, qy, qz, I3d)
    fig.show()

    # 2D slice at the lowest qz (qz = -1, first index of the 'ij' third axis)
    fig = model2sas.Figure()
    fig.plot_surface2d(I3d[:, :, 0])
    fig.show()

159 

160 fig = model2sas.Figure() 

161 q1d = torch.linspace(-1, 1, 50) 

162 qx, qy, qz = torch.meshgrid(q1d, q1d, q1d, indexing='ij') 

163 I = model.intensity(qx, qy, qz) 

164 fig.plot_volume3d(qx, qy, qz, I) 

165 fig.show() 

166 

167 fig = model2sas.Figure() 

168 fig.plot_surface2d(I[:,:,0]) 

169 fig.show() 

170 

171 

# Manual entry point: running this file directly executes only the assembly
# test. pytest discovers and runs every test_* function independently of this.
if __name__ == '__main__':
    test_assemblymodel_and_transform()

174