Coverage for src\model2sas\calcfunc.py: 92%
106 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-09 17:12 +0800
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-09 17:12 +0800
1"""Compute-related functions in model2sas.
2All based on pytorch instead of numpy.
3"""
5from typing import Literal
6import math
8# import numpy as np
9import torch
10from torch import Tensor
12from .utils import log
13from .global_vars import PRECISION
def _wrap_to_two_pi(theta: Tensor) -> Tensor:
    """Add 2*pi to negative angles so that arctan2 output lies in [0, 2pi]."""
    return torch.where(theta < 0, theta + 2*torch.pi, theta)


class __CoordConverter:
    """Callable converter between cartesian, spherical and cylindrical coordinates.

    Conventions used by the conversion methods:
        'car': (x, y, z)
        'sph': (r, theta, phi), theta = arctan2(y, x) wrapped to [0, 2pi],
               phi = arccos(z/r) measured from the +z axis
        'cyl': (rho, theta, z), theta = arctan2(y, x) wrapped to [0, 2pi]

    All conversion functions are stored once in a lookup table keyed by
    '<original>2<target>', so a call is a single dictionary lookup.
    """

    def __init__(self) -> None:
        self.convert_func = {
            'car2car': self.nochange,
            'car2sph': self.car2sph,
            'car2cyl': self.car2cyl,
            'sph2car': self.sph2car,
            'sph2sph': self.nochange,
            # no direct formula: go through cartesian coordinates
            'sph2cyl': lambda a, b, c: self.car2cyl(*self.sph2car(a, b, c)),
            'cyl2car': self.cyl2car,
            'cyl2sph': lambda a, b, c: self.car2sph(*self.cyl2car(a, b, c)),
            'cyl2cyl': self.nochange,
        }

    # @log
    def __call__(self, u: Tensor, v: Tensor, w: Tensor, original_coord: Literal['car', 'sph', 'cyl'], target_coord: Literal['car', 'sph', 'cyl']) -> tuple[Tensor, Tensor, Tensor]:
        """Convert coordinates (u, v, w) from original_coord to target_coord.

        Args:
            u (Tensor): first coordinate in the original system
            v (Tensor): second coordinate in the original system
            w (Tensor): third coordinate in the original system
            original_coord: one of 'car', 'sph', 'cyl'
            target_coord: one of 'car', 'sph', 'cyl'

        Returns:
            tuple[Tensor, Tensor, Tensor]: coordinates in the target system

        Raises:
            KeyError: if the coordinate system pair is not in the lookup table
        """
        return self.convert_func[f'{original_coord}2{target_coord}'](u, v, w)

    @staticmethod
    def nochange(u: Tensor, v: Tensor, w: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Return the inputs untouched (same coordinate system)."""
        return u, v, w

    @staticmethod
    def car2sph(x: Tensor, y: Tensor, z: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Cartesian (x, y, z) -> spherical (r, theta, phi)."""
        r = torch.sqrt(x**2 + y**2 + z**2)
        # Clamp before arccos: when r underflows to 0 for a tiny non-zero z,
        # z/r is +-inf and arccos would give nan, which would then be turned
        # into phi=0 even for points on the negative z axis.
        cos_phi = torch.clamp(z / r, min=-1.0, max=1.0)
        phi = torch.arccos(cos_phi)  # still nan when z/r is 0/0
        phi = torch.nan_to_num(phi, nan=0.)  # convert nan to 0
        theta = _wrap_to_two_pi(torch.arctan2(y, x))  # arctan2 range is [-pi, pi]
        return r, theta, phi

    @staticmethod
    def car2cyl(x: Tensor, y: Tensor, z: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Cartesian (x, y, z) -> cylindrical (rho, theta, z)."""
        rho = torch.sqrt(x**2 + y**2)
        theta = _wrap_to_two_pi(torch.arctan2(y, x))  # arctan2 range is [-pi, pi]
        return rho, theta, z

    @staticmethod
    def sph2car(r: Tensor, theta: Tensor, phi: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Spherical (r, theta, phi) -> cartesian (x, y, z)."""
        sinphi = torch.sin(phi)
        x = r * torch.cos(theta) * sinphi
        y = r * torch.sin(theta) * sinphi
        z = r * torch.cos(phi)
        return x, y, z

    @staticmethod
    def cyl2car(rho: Tensor, theta: Tensor, z: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Cylindrical (rho, theta, z) -> cartesian (x, y, z)."""
        x = rho * torch.cos(theta)
        y = rho * torch.sin(theta)
        return x, y, z


# use just like a function
convert_coord = __CoordConverter()
71@log
def moller_trumbore_intersect_count(points: Tensor, ray: Tensor, triangles: Tensor) -> Tensor:
    """Count, for every point, how many triangles are hit by the ray
    starting at that point, using the Möller-Trumbore intersection algorithm.
    See paper https://doi.org/10.1080/10867651.1997.10487468
    All variable names follow this paper.

    Args:
        points (Tensor): shape=(n1, ..., ni, 3), 3D meshgrid points coordinates
        ray (Tensor): shape=(3,)
        triangles (Tensor): shape=(m, 3, 3), m triangles with 3 vertices, each (3,) coordinates

    Returns:
        Tensor: shape=(n1, ..., ni), indicate intersect counts per point
    """
    #* Fully vectorized over points and triangles, no python loops.

    O = points                    # (n1, ..., ni, 3)
    D = ray                       # (3,)
    V0 = triangles[:, 0, :]       # (m, 3)
    E1 = triangles[:, 1, :] - V0  # (m, 3)
    E2 = triangles[:, 2, :] - V0  # (m, 3)

    # leading singleton dims so per-triangle tensors broadcast against
    # points of any shape
    points_coord_dim = (1,) * (O.dim() - 1)

    T = O.unsqueeze(-2) - V0.view(*points_coord_dim, *V0.shape)                  # (n1, ..., ni, m, 3)
    P = torch.linalg.cross(D.unsqueeze(0), E2, dim=-1)                           # (m, 3)
    Q = torch.linalg.cross(T, E1.view(*points_coord_dim, *E1.shape), dim=-1)     # (n1, ..., ni, m, 3)

    PE1_reciprocal = 1 / torch.linalg.vecdot(P, E1, dim=-1)                      # (m,)
    QE2 = torch.linalg.vecdot(Q, E2.view(*points_coord_dim, *E2.shape), dim=-1)  # (n1, ..., ni, m)
    PT = torch.linalg.vecdot(P.view(*points_coord_dim, *P.shape), T, dim=-1)     # (n1, ..., ni, m)
    QD = torch.linalg.vecdot(Q, D, dim=-1)                                       # (n1, ..., ni, m)

    PE1_reciprocal = PE1_reciprocal.view(*points_coord_dim, *PE1_reciprocal.shape)  # (1, ..., 1, m)
    t = PE1_reciprocal * QE2  # (n1, ..., ni, m), distance along the ray
    u = PE1_reciprocal * PT   # (n1, ..., ni, m), barycentric coordinate
    v = PE1_reciprocal * QD   # (n1, ..., ni, m), barycentric coordinate

    # Hit: in front of the origin (t>0) and strictly inside the triangle.
    # Summing the boolean mask directly yields the same int64 counts as
    # filling an int32 tensor first, without allocating that extra
    # (n1, ..., ni, m) tensor.
    intersect = (t > 0) & (u > 0) & (v > 0) & ((u + v) < 1)  # (n1, ..., ni, m)
    intersect_count = intersect.sum(-1)                       # (n1, ..., ni)
    return intersect_count
118@log
119def complex_increase_argument(complex: Tensor, argument_addend: Tensor|float) -> Tensor:
120 """Args:
121 complex (Tensor): complex dtype tensor
122 argument_addend (Tensor): same shape as complex, or a float apply to all
124 Returns:
125 Tensor: complex dtype tensor
126 """
127 mod, arg = torch.sqrt(complex.real**2 + complex.imag**2), torch.arctan2(complex.imag, complex.real)
128 arg = arg + argument_addend
129 return mod * torch.complex(torch.cos(arg), torch.sin(arg))
132# @log
133# def nearest_interp(x:Tensor, y:Tensor, z:Tensor, px:Tensor, py:Tensor, pz:Tensor, c:Tensor, d:float | Tensor) -> Tensor:
134# """Conduct nearest interpolate on equally spaced meshgrid.
135# When the grid values c are complex, this is equivalent to interpolating the real and imaginary parts separately
137# Args:
138# x (Tensor): any shape, x coordinates of points to be interpolated
139# y (Tensor): any shape, y coordinates of points to be interpolated
140# z (Tensor): any shape, z coordinates of points to be interpolated
141# px (Tensor): shape=(m1,), x1d grid of meshgrid with known values
142# py (Tensor): shape=(m2,), y1d grid of meshgrid with known values
143# pz (Tensor): shape=(m3,), z1d grid of meshgrid with known values
144# c (Tensor): shape=(m1, m2, m3), values of each of in meshgrid(px, py, pz)
145# d (float | Tensor): spacing of meshgrid(px, py, pz), equally spaced
147# Returns:
148# Tensor: same shape as x|y|z, interpolated values of (x, y, z)
149# """
150# ix, iy, iz = (x-px[0]+d/2)/d, (y-py[0]+d/2)/d, (z-pz[0]+d/2)/d
151# ix, iy, iz = ix.to(torch.int64), iy.to(torch.int64), iz.to(torch.int64) # tensors used as indices must be long, byte or bool tensors
152# c_interp = c[ix, iy, iz]
153# return c_interp
155@log
156def trilinear_interp(x:Tensor, y:Tensor, z:Tensor, px:Tensor, py:Tensor, pz:Tensor, c:Tensor, d:float | Tensor) -> Tensor:
157 """Conduct trilinear interpolate on equally spaced meshgrid.
158 当网格值c是复数时等效于对实部和虚部分别进行插值
160 Args:
161 x (Tensor): any shape, x coordinates of points to be interpolated
162 y (Tensor): any shape, y coordinates of points to be interpolated
163 z (Tensor): any shape, z coordinates of points to be interpolated
164 px (Tensor): shape=(m1,), x1d grid of meshgrid with known values
165 py (Tensor): shape=(m2,), y1d grid of meshgrid with known values
166 pz (Tensor): shape=(m3,), z1d grid of meshgrid with known values
167 c (Tensor): shape=(m1, m2, m3), values of each of in meshgrid(px, py, pz)
168 d (float | Tensor): spacing of meshgrid(px, py, pz), equally spaced
170 Returns:
171 Tensor: same shape as x|y|z, interpolated values of (x, y, z)
172 """
173 ix, iy, iz = (x-px[0])/d, (y-py[0])/d, (z-pz[0])/d
174 ix, iy, iz = ix.to(torch.int64), iy.to(torch.int64), iz.to(torch.int64) # tensors used as indices must be long, byte or bool tensors
176 x0, y0, z0 = px[ix], py[iy], pz[iz]
177 x1, y1, z1 = px[ix+1], py[iy+1], pz[iz+1]
178 xd, yd, zd = (x-x0)/(x1-x0), (y-y0)/(y1-y0), (z-z0)/(z1-z0)
180 c_interp = c[ix, iy, iz]*(1-xd)*(1-yd)*(1-zd) \
181 + c[ix+1, iy, iz]*xd*(1-yd)*(1-zd) \
182 + c[ix, iy+1, iz]*(1-xd)*yd*(1-zd) \
183 + c[ix, iy, iz+1]*(1-xd)*(1-yd)*zd \
184 + c[ix+1, iy, iz+1]*xd*(1-yd)*zd \
185 + c[ix, iy+1, iz+1]*(1-xd)*yd*zd \
186 + c[ix+1, iy+1, iz]*xd*yd*(1-zd) \
187 + c[ix+1, iy+1, iz+1]*xd*yd*zd
188 return c_interp
191@log
def euler_rodrigues_rotate(x: Tensor, y: Tensor, z: Tensor, v_axis: tuple[float, float, float], angle: float) -> tuple[Tensor, Tensor, Tensor]:
    """Rotate coordinates about an axis through the origin with the
    Euler-Rodrigues formula.
    Refer to https://en.wikipedia.org/wiki/Euler%E2%80%93Rodrigues_formula

    Args:
        x (Tensor): any shape, x coordinates of points to be rotated
        y (Tensor): any shape, y coordinates of points to be rotated
        z (Tensor): any shape, z coordinates of points to be rotated
        v_axis (tuple[float, float, float]): axis of rotation, any non-zero length
        angle (float): in radian

    Returns:
        tuple[Tensor, Tensor, Tensor]: rotated coordinates rx, ry, rz
    """
    axis_length = math.dist(v_axis, (0,0,0))
    half_angle = angle/2

    # Euler-Rodrigues parameters; the axis is normalised to a unit vector
    a = math.cos(half_angle)
    axis_x, axis_y, axis_z = v_axis[0], v_axis[1], v_axis[2]
    b = axis_x/axis_length * math.sin(half_angle)
    c = axis_y/axis_length * math.sin(half_angle)
    d = axis_z/axis_length * math.sin(half_angle)

    # entries of the rotation matrix, row by row (plain python floats)
    m_xx = a**2 + b**2 - c**2 - d**2
    m_xy = 2*(b*c - a*d)
    m_xz = 2*(b*d + a*c)

    m_yx = 2*(b*c + a*d)
    m_yy = a**2 + c**2 - b**2 - d**2
    m_yz = 2*(c*d - a*b)

    m_zx = 2*(b*d - a*c)
    m_zy = 2*(c*d + a*b)
    m_zz = a**2 + d**2 - b**2 - c**2

    # component-wise matrix-vector product, applied elementwise to the tensors
    rx = m_xx * x + m_xy * y + m_xz * z
    ry = m_yx * x + m_yy * y + m_yz * z
    rz = m_zx * x + m_zy * y + m_zz * z
    return rx, ry, rz
229@log
def multiple_spherical_sampling(Rs: Tensor, Ns: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """Generate sampling points using fibonacci grid
    for multiple spherical shells. Rs and Ns are 1D tensors
    with same shape.

    Args:
        Rs (Tensor): radius of spherical shells
        Ns (Tensor): number of points on each shell

    Returns:
        tuple[Tensor, Tensor, Tensor]: x, y, z coordinates of the points
        of all shells, concatenated shell after shell
    """
    # Keep the python loop minimal: only build the per-shell index, count
    # and radius vectors here, all math is done once on the concatenation.
    index_parts, count_parts, radius_parts = [], [], []
    for R, N in zip(Rs, Ns):
        shell_index = torch.arange(1, N+1, dtype=PRECISION)  # 1..N on this shell
        index_parts.append(shell_index)
        count_parts.append(torch.full_like(shell_index, N))
        radius_parts.append(torch.full_like(shell_index, R))
    n = torch.cat(index_parts)
    N = torch.cat(count_parts)
    r = torch.cat(radius_parts)

    phi = (torch.sqrt(torch.tensor(5, dtype=PRECISION))-1)/2
    unit_z = (2*n-1)/N - 1               # z on the unit sphere
    planar = torch.sqrt(1-unit_z**2)     # distance from the z axis on the unit sphere
    azimuth = 2*torch.pi*n*phi
    x = r * planar * torch.cos(azimuth)
    y = r * planar * torch.sin(azimuth)
    z = r * unit_z
    return x, y, z