Coverage for src\model2sas\calcfunc.py: 92%

106 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-09 17:12 +0800

1"""Compute-related functions in model2sas. 

2All based on pytorch instead of numpy. 

3""" 

4 

5from typing import Literal 

6import math 

7 

8# import numpy as np 

9import torch 

10from torch import Tensor 

11 

12from .utils import log 

13from .global_vars import PRECISION 

14 

15 

class __CoordConverter:
    """Convert point coordinates between Cartesian, spherical and cylindrical systems.

    Coordinate triples are ordered as:
        'car': (x, y, z)
        'sph': (r, theta, phi) -- theta is the azimuthal angle in the xy-plane,
               phi is the polar angle measured from the +z axis
        'cyl': (rho, theta, z) -- theta is the azimuthal angle in the xy-plane

    Azimuthal angles produced by this class are wrapped into [0, 2*pi].
    All conversion functions are pre-stored in a lookup table keyed by
    '<original>2<target>'; sph <-> cyl conversions go through Cartesian.
    """
    def __init__(self) -> None:
        self.convert_func = {
            'car2car': self.nochange,
            'car2sph': self.car2sph,
            'car2cyl': self.car2cyl,
            'sph2car': self.sph2car,
            'sph2sph': self.nochange,
            'sph2cyl': lambda a, b, c: self.car2cyl(*self.sph2car(a, b, c)),
            'cyl2car': self.cyl2car,
            'cyl2sph': lambda a, b, c: self.car2sph(*self.cyl2car(a, b, c)),
            'cyl2cyl': self.nochange,
        }

    # @log
    def __call__(self, u: Tensor, v: Tensor, w: Tensor, original_coord: Literal['car', 'sph', 'cyl'], target_coord: Literal['car', 'sph', 'cyl']) -> tuple[Tensor, Tensor, Tensor]:
        """Convert the coordinate triple (u, v, w) from `original_coord` to `target_coord`.

        Raises:
            KeyError: if either coordinate name is not 'car', 'sph' or 'cyl'.
        """
        return self.convert_func[f'{original_coord}2{target_coord}'](u, v, w)

    @staticmethod
    def nochange(u: Tensor, v: Tensor, w: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Identity conversion: return the inputs unchanged (same objects)."""
        return u, v, w

    @staticmethod
    def car2sph(x: Tensor, y: Tensor, z: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Cartesian (x, y, z) -> spherical (r, theta, phi)."""
        r = torch.sqrt(x**2 + y**2 + z**2)
        # At the origin z/r is 0/0 = nan; define phi = 0 there. Only the r == 0
        # case is masked, so NaNs caused by NaN inputs still propagate.
        phi = torch.arccos(z / r).masked_fill(r == 0, 0.)
        theta = torch.arctan2(y, x)  # range [-pi, pi]
        theta = torch.where(theta < 0, theta + 2 * torch.pi, theta)  # convert range to [0, 2pi]
        return r, theta, phi

    @staticmethod
    def car2cyl(x: Tensor, y: Tensor, z: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Cartesian (x, y, z) -> cylindrical (rho, theta, z)."""
        rho = torch.sqrt(x**2 + y**2)
        theta = torch.arctan2(y, x)  # range [-pi, pi]
        theta = torch.where(theta < 0, theta + 2 * torch.pi, theta)  # convert range to [0, 2pi]
        return rho, theta, z

    @staticmethod
    def sph2car(r: Tensor, theta: Tensor, phi: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Spherical (r, theta, phi) -> Cartesian (x, y, z)."""
        sinphi = torch.sin(phi)
        x = r * torch.cos(theta) * sinphi
        y = r * torch.sin(theta) * sinphi
        z = r * torch.cos(phi)
        return x, y, z

    @staticmethod
    def cyl2car(rho: Tensor, theta: Tensor, z: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Cylindrical (rho, theta, z) -> Cartesian (x, y, z); z is passed through."""
        x = rho * torch.cos(theta)
        y = rho * torch.sin(theta)
        return x, y, z


# use just like a function
convert_coord = __CoordConverter()

69 

70 

@log
def moller_trumbore_intersect_count(points: Tensor, ray: Tensor, triangles: Tensor) -> Tensor:
    """Count, for every point, how many triangles the ray cast from it hits.

    Uses the Möller-Trumbore ray-triangle intersection algorithm, evaluated
    for all (point, triangle) pairs at once by broadcasting (no Python loops,
    which gives the best performance so far, especially on GPU).
    See paper https://doi.org/10.1080/10867651.1997.10487468
    All variable names follow this paper.

    Args:
        points (Tensor): shape=(n1, ..., ni, 3), ray origins, e.g. 3D meshgrid point coordinates
        ray (Tensor): shape=(3,), ray direction shared by all points
        triangles (Tensor): shape=(m, 3, 3), m triangles with 3 vertices, each (3,) coordinates

    Returns:
        Tensor: shape=(n1, ..., ni), int64 intersection count per point

    Note:
        Only hits strictly inside a triangle (u > 0, v > 0, u + v < 1) and
        strictly in front of the origin (t > 0) are counted, so hits exactly on
        an edge or vertex are not. No epsilon tolerance is applied: for a
        triangle exactly parallel to the ray the determinant is 0, its
        reciprocal is inf, and the resulting inf/nan values fail the
        comparisons, so it is never counted.
    """
    O = points  # (n1, ..., ni, 3)
    D = ray  # (3,)
    V0 = triangles[:, 0, :]  # (m, 3)
    E1 = triangles[:, 1, :] - V0  # (m, 3)
    E2 = triangles[:, 2, :] - V0  # (m, 3)

    # Broadcasting over trailing dims pairs every point with every triangle:
    # (n1, ..., ni, 1, 3) against (m, 3) -> (n1, ..., ni, m, 3).
    T = O.unsqueeze(-2) - V0  # (n1, ..., ni, m, 3)
    P = torch.linalg.cross(D.unsqueeze(0), E2, dim=-1)  # (m, 3)
    Q = torch.linalg.cross(T, E1, dim=-1)  # (n1, ..., ni, m, 3)

    PE1_reciprocal = 1 / torch.linalg.vecdot(P, E1, dim=-1)  # (m,), 1/determinant
    t = PE1_reciprocal * torch.linalg.vecdot(Q, E2, dim=-1)  # (n1, ..., ni, m)
    u = PE1_reciprocal * torch.linalg.vecdot(P, T, dim=-1)  # (n1, ..., ni, m)
    v = PE1_reciprocal * torch.linalg.vecdot(Q, D, dim=-1)  # (n1, ..., ni, m)

    # Boolean hit mask summed directly; torch.sum promotes bool to int64.
    hit = (t > 0) & (u > 0) & (v > 0) & ((u + v) < 1)  # (n1, ..., ni, m)
    return hit.sum(dim=-1)  # (n1, ..., ni)

115 

116 

117 

@log
def complex_increase_argument(complex: Tensor, argument_addend: Tensor|float) -> Tensor:
    """Add `argument_addend` to the argument (phase) of complex values, keeping the modulus.

    Args:
        complex (Tensor): complex dtype tensor
        argument_addend (Tensor | float): angle in radians to add; same shape as
            `complex` (or broadcastable to it), or a float applied to all elements

    Returns:
        Tensor: complex dtype tensor with unchanged modulus and increased argument
    """
    # torch.abs / torch.angle on the complex tensor avoid squaring the real and
    # imaginary parts, which can overflow/underflow for very large/small values.
    mod = torch.abs(complex)
    arg = torch.angle(complex) + argument_addend
    return mod * torch.complex(torch.cos(arg), torch.sin(arg))

130 

131 

132# @log 

133# def nearest_interp(x:Tensor, y:Tensor, z:Tensor, px:Tensor, py:Tensor, pz:Tensor, c:Tensor, d:float | Tensor) -> Tensor: 

134# """Conduct nearest interpolate on equally spaced meshgrid. 

135# When the grid values c are complex, this is equivalent to interpolating the real and imaginary parts separately.

136 

137# Args: 

138# x (Tensor): any shape, x coordinates of points to be interpolated 

139# y (Tensor): any shape, y coordinates of points to be interpolated 

140# z (Tensor): any shape, z coordinates of points to be interpolated 

141# px (Tensor): shape=(m1,), x1d grid of meshgrid with known values 

142# py (Tensor): shape=(m2,), y1d grid of meshgrid with known values 

143# pz (Tensor): shape=(m3,), z1d grid of meshgrid with known values 

144# c (Tensor): shape=(m1, m2, m3), values of each of in meshgrid(px, py, pz) 

145# d (float | Tensor): spacing of meshgrid(px, py, pz), equally spaced 

146 

147# Returns: 

148# Tensor: same shape as x|y|z, interpolated values of (x, y, z) 

149# """ 

150# ix, iy, iz = (x-px[0]+d/2)/d, (y-py[0]+d/2)/d, (z-pz[0]+d/2)/d 

151# ix, iy, iz = ix.to(torch.int64), iy.to(torch.int64), iz.to(torch.int64) # tensors used as indices must be long, byte or bool tensors 

152# c_interp = c[ix, iy, iz] 

153# return c_interp 

154 

@log
def trilinear_interp(x:Tensor, y:Tensor, z:Tensor, px:Tensor, py:Tensor, pz:Tensor, c:Tensor, d:float | Tensor) -> Tensor:
    """Conduct trilinear interpolation on an equally spaced meshgrid.

    When the grid values c are complex, this is equivalent to interpolating
    the real and imaginary parts separately.

    Args:
        x (Tensor): any shape, x coordinates of points to be interpolated
        y (Tensor): any shape (broadcastable with x), y coordinates of points to be interpolated
        z (Tensor): any shape (broadcastable with x), z coordinates of points to be interpolated
        px (Tensor): shape=(m1,), m1 >= 2, x1d grid of meshgrid with known values
        py (Tensor): shape=(m2,), m2 >= 2, y1d grid of meshgrid with known values
        pz (Tensor): shape=(m3,), m3 >= 2, z1d grid of meshgrid with known values
        c (Tensor): shape=(m1, m2, m3), values at each point of meshgrid(px, py, pz)
        d (float | Tensor): spacing of meshgrid(px, py, pz), equally spaced

    Returns:
        Tensor: same shape as x|y|z, interpolated values of (x, y, z)
    """
    # Index of the lower corner of the cell containing each point (int64 is
    # required for tensor indexing; the cast truncates toward zero).
    # Clamping to [0, m-2] keeps both i and i+1 valid: a point exactly on the
    # last grid plane (e.g. x == px[-1]) uses the last cell with weight 1 on
    # its upper corner instead of indexing past the end, and points outside
    # the grid are linearly extrapolated from the nearest edge cell instead of
    # raising or wrapping around to negative indices.
    ix = ((x - px[0]) / d).to(torch.int64).clamp(0, px.shape[0] - 2)
    iy = ((y - py[0]) / d).to(torch.int64).clamp(0, py.shape[0] - 2)
    iz = ((z - pz[0]) / d).to(torch.int64).clamp(0, pz.shape[0] - 2)

    x0, y0, z0 = px[ix], py[iy], pz[iz]
    x1, y1, z1 = px[ix+1], py[iy+1], pz[iz+1]
    # Fractional position inside the cell along each axis.
    xd, yd, zd = (x-x0)/(x1-x0), (y-y0)/(y1-y0), (z-z0)/(z1-z0)

    c_interp = c[ix, iy, iz]*(1-xd)*(1-yd)*(1-zd) \
        + c[ix+1, iy, iz]*xd*(1-yd)*(1-zd) \
        + c[ix, iy+1, iz]*(1-xd)*yd*(1-zd) \
        + c[ix, iy, iz+1]*(1-xd)*(1-yd)*zd \
        + c[ix+1, iy, iz+1]*xd*(1-yd)*zd \
        + c[ix, iy+1, iz+1]*(1-xd)*yd*zd \
        + c[ix+1, iy+1, iz]*xd*yd*(1-zd) \
        + c[ix+1, iy+1, iz+1]*xd*yd*zd
    return c_interp

189 

190 

@log
def euler_rodrigues_rotate(x: Tensor, y: Tensor, z: Tensor, v_axis: tuple[float, float, float], angle: float) -> tuple[Tensor, Tensor, Tensor]:
    """Rotate points about an axis through the origin with the Euler-Rodrigues formula.

    Refer to https://en.wikipedia.org/wiki/Euler%E2%80%93Rodrigues_formula

    Args:
        x (Tensor): any shape, x coordinates of the points to rotate
        y (Tensor): any shape, y coordinates of the points to rotate
        z (Tensor): any shape, z coordinates of the points to rotate
        v_axis (tuple[float, float, float]): rotation axis direction; it is
            normalized internally, so any non-zero length works
        angle (float): rotation angle in radians

    Returns:
        tuple[Tensor, Tensor, Tensor]: rotated coordinates rx, ry, rz, same shape and device as input.
    """
    axis_len = math.dist(v_axis, (0, 0, 0))
    half_sin = math.sin(angle/2)

    # Euler-Rodrigues parameters: (b, c, d) is the unit axis scaled by sin(angle/2).
    a = math.cos(angle/2)
    b, c, d = ((component/axis_len) * half_sin for component in v_axis)

    aa, bb, cc, dd = a**2, b**2, c**2, d**2

    # Rotation matrix entries, row by row.
    m00, m01, m02 = aa + bb - cc - dd, 2*(b*c - a*d), 2*(b*d + a*c)
    m10, m11, m12 = 2*(b*c + a*d), aa + cc - bb - dd, 2*(c*d - a*b)
    m20, m21, m22 = 2*(b*d - a*c), 2*(c*d + a*b), aa + dd - bb - cc

    # Scalar-coefficient form: the author measured it faster than a
    # vector/matrix formulation.
    rx = m00*x + m01*y + m02*z
    ry = m10*x + m11*y + m12*z
    rz = m20*x + m21*y + m22*z
    return rx, ry, rz

227 

228 

@log
def multiple_spherical_sampling(Rs: Tensor, Ns: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """Generate Fibonacci-grid sampling points on several spherical shells.

    Shell k has radius Rs[k] and receives Ns[k] points; the points of all
    shells are concatenated in shell order. Rs and Ns are 1D tensors of the
    same length (pairs are formed with zip, so extra entries of the longer
    one are ignored).

    Args:
        Rs (Tensor): radius of each spherical shell
        Ns (Tensor): number of points on each shell

    Returns:
        tuple[Tensor, Tensor, Tensor]: x, y, z coordinates, each 1D with
        sum(Ns) elements and dtype PRECISION; empty tensors if there are no shells.
    """
    # For better performance the Python loop only builds per-point bookkeeping
    # (index n, shell size N, radius r); all trigonometry runs once below on
    # the concatenated tensors.
    def gen_nNr(R, N):
        n = torch.arange(1, N+1, dtype=PRECISION)  # point index 1..N on this shell
        return n, torch.full_like(n, N), torch.full_like(n, R)

    nNr = [gen_nNr(R, N) for R, N in zip(Rs, Ns)]
    if not nNr:
        # torch.cat([]) raises; no shells simply means no points.
        return (
            torch.empty(0, dtype=PRECISION),
            torch.empty(0, dtype=PRECISION),
            torch.empty(0, dtype=PRECISION),
        )
    n, N, r = (torch.cat(parts) for parts in zip(*nNr))

    phi = (torch.sqrt(torch.tensor(5, dtype=PRECISION))-1)/2  # golden ratio conjugate, (sqrt(5)-1)/2
    z = ((2*n-1)/N - 1)  # evenly spaced heights in (-1, 1) on the unit sphere
    ring_radius = torch.sqrt(1-z**2)
    azimuth = 2*torch.pi*n*phi
    x = r * ring_radius * torch.cos(azimuth)
    y = r * ring_radius * torch.sin(azimuth)
    z = r * z
    return x, y, z

261 

262 

263