Coverage for src\model2sas\plot.py: 67%

92 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:26 +0800

1"""All kinds of plot, from 1d to 3d. 

2""" 

3 

4from pathlib import Path 

5from typing import Literal, Optional, Sequence 

6 

7import torch 

8from torch import Tensor 

9 

10import plotly.graph_objects as go 

11 

12 

class Voxel(go.Mesh3d):
    """Mesh3d trace that draws one axis-aligned cube ("voxel") per center point.

    Args:
        xc, yc, zc: x / y / z coordinates of the cube centers. Tensors (or anything
            ``torch.as_tensor`` accepts) with the same number of elements; any shape
            is accepted and flattened in row-major order.
        spacing: cube edge length (a scalar, or a 1-D tensor with one value per point).
        **kwargs: forwarded to ``plotly.graph_objects.Mesh3d``
            (e.g. ``name``, ``color``, ``opacity``, ``showlegend``).

    When no centers are given, the trace is built like a plain ``Mesh3d``. A ``dict``
    or ``Mesh3d`` passed as ``xc`` is forwarded as the ``arg`` of ``Mesh3d``. plotly's
    ``__reduce__`` reconstructs traces this way, so ``copy.deepcopy`` / pickling of a
    ``Voxel`` works.
    """

    def __init__(self, xc=None, yc=None, zc=None, spacing=None, **kwargs):
        if isinstance(xc, go.Mesh3d):
            xc = xc.to_plotly_json()
        if isinstance(xc, dict):
            # Rebuilding from serialized properties: vertices/triangles already exist.
            super().__init__(xc, **kwargs)
            return
        if xc is None and yc is None and zc is None:
            super().__init__(**kwargs)
            return
        if xc is None or yc is None or zc is None or spacing is None:
            raise TypeError('Voxel needs all of xc, yc, zc and spacing')
        x, y, z, i, j, k = self.gen_vertices_triangles(xc, yc, zc, spacing)
        super().__init__(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs)

    def gen_vertices_triangles(self, xc, yc, zc, spacing) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Generate vertices and triangles for mesh plot.

        For each center point (xc, yc, zc), build a cube with edge length ``spacing``
        from 8 vertices and 12 triangles (two per face).

        Args:
            xc: x coordinates of the center points (any shape, flattened)
            yc: y coordinates of the center points (any shape, flattened)
            zc: z coordinates of the center points (any shape, flattened)
            spacing: spacing of the mesh grid, used as the cube edge length

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: ``(x, y, z, i, j, k)``.
            ``x, y, z`` are vertex coordinates, 8 per point (length ``8 * N``).
            ``i, j, k`` are the vertex indices of each triangle, 12 per point
            (length ``12 * N``). ``N`` is the number of center points.
        """
        # Flatten first so that the 8 vertices of each point are contiguous after the
        # stack/flatten below, which is what the index offsets assume. Stacking
        # multi-dimensional inputs along dim=1 would interleave different points.
        xc, yc, zc = (torch.as_tensor(c).flatten() for c in (xc, yc, zc))
        h = spacing / 2

        # Vertex order per cube: 0-3 on the bottom face (z - h), 4-7 on the top face
        # (z + h). Each face walks (x, y) = (-,-), (+,-), (+,+), (-,+).
        xv = torch.stack((xc-h, xc+h, xc+h, xc-h, xc-h, xc+h, xc+h, xc-h), dim=1).flatten()
        yv = torch.stack((yc-h, yc-h, yc+h, yc+h, yc-h, yc-h, yc+h, yc+h), dim=1).flatten()
        zv = torch.stack((zc-h, zc-h, zc-h, zc-h, zc+h, zc+h, zc+h, zc+h), dim=1).flatten()

        # Triangles of one cube. Consecutive pairs cover the faces
        # z-, y-, x+, y+, x-, z+.
        i0 = torch.tensor((0, 2, 0, 5, 1, 6, 2, 7, 3, 4, 4, 6))
        j0 = torch.tensor((1, 3, 1, 4, 2, 5, 3, 6, 0, 7, 5, 7))
        k0 = torch.tensor((2, 0, 5, 0, 6, 1, 7, 2, 4, 3, 6, 4))
        # Shift every cube's indices to its own block of 8 vertices; shape (N, 1).
        offset = 8 * torch.arange(xc.numel()).unsqueeze(1)
        i = (i0 + offset).flatten()
        j = (j0 + offset).flatten()
        k = (k0 + offset).flatten()
        return xv, yv, zv, i, j, k

44 

45 

class Figure(go.Figure):
    """``plotly.graph_objects.Figure`` with helpers for 1D, 2D and 3D plots.

    Each ``plot_*`` method adds trace(s) to this figure in place and returns ``None``.
    Tensor arguments are detached and moved to CPU before being handed to plotly.
    """

    def __init__(self, data=None, layout=None, frames=None, skip_invalid=False, **kwargs):
        super().__init__(data, layout, frames, skip_invalid, **kwargs)

    def write_html(self, filename: str | Path, *args, **kwargs) -> None:
        """Write the figure as an HTML document; files are always encoded as UTF-8.

        Args:
            filename: output path, or a writable text object with a ``write`` method.
            *args: forwarded positionally to :meth:`to_html` (follow its parameter order).
            **kwargs: forwarded to :meth:`to_html` (e.g. ``include_plotlyjs``,
                ``full_html``, ``config``). ``auto_open=True`` is also accepted and
                opens the written file in a web browser when ``filename`` is a path.
        """
        auto_open = kwargs.pop('auto_open', False)  # not a to_html parameter
        html = self.to_html(*args, **kwargs)
        if hasattr(filename, 'write'):
            filename.write(html)
            return
        path = Path(filename)
        # Explicit encoding instead of the platform default.
        path.write_text(html, encoding='utf-8')
        if auto_open:
            import webbrowser
            webbrowser.open(path.absolute().as_uri())

    def set_title_text(self, text: str) -> None:
        """Set the figure title text."""
        self.update_layout(title_text=text)

    def set_template(self, plotly_template: str) -> None:
        """Apply a plotly layout template by name."""
        self.update_layout(template=plotly_template)

    @staticmethod
    def __ensure_cpu_tensor(*t):
        """Detach tensors and move them to CPU; pass other values through unchanged.

        Detaching lets tensors that require grad be converted to numpy by plotly.
        With a single argument the converted value is returned directly; otherwise
        a tuple in the same order is returned.
        """
        converted = tuple(ti.detach().cpu() if isinstance(ti, Tensor) else ti for ti in t)
        return converted[0] if len(converted) == 1 else converted

    @staticmethod
    def __safe_log10(t: Tensor) -> Tensor:
        """Base-10 logarithm with ``nan`` (negative input) and ``-inf`` (zero input) set to 0.

        ``+inf`` is replaced by the largest finite value of the dtype
        (``nan_to_num`` default).
        """
        return torch.log10(t).nan_to_num(nan=0., neginf=0.)

    def __update_coloraxis(self, colorscale: Optional[str] = None, colorbar_title: Optional[str] = None) -> None:
        """Configure the shared ``layout.coloraxis`` referenced by the 3D traces.

        A trace with ``coloraxis='coloraxis'`` ignores its own colorscale and colorbar,
        so these settings must be placed on the layout. ``None`` values are skipped:
        writing ``None`` would reset a value configured earlier.
        """
        coloraxis = {}
        if colorscale is not None:
            coloraxis['colorscale'] = colorscale
        if colorbar_title is not None:
            coloraxis['colorbar'] = {'title': {'text': colorbar_title}}
        if coloraxis:
            self.update_layout(coloraxis=coloraxis)

    def plot_curve1d(self, x: Sequence|Tensor, y: Sequence|Tensor, name: Optional[str] = None, mode: Optional[Literal['lines', 'markers', 'lines+markers']] = None, logx: bool = True, logy: bool = True) -> None:
        """Add a 1D curve as a scatter trace.

        Args:
            x: x coordinates of the curve.
            y: y coordinates of the curve.
            name: trace name shown in the legend.
            mode: plotly scatter mode; ``None`` leaves the choice to plotly.
            logx: switch the x axis to log scale.
            logy: switch the y axis to log scale.
        """
        x, y = self.__ensure_cpu_tensor(x, y)
        self.add_trace(go.Scatter(x=x, y=y, name=name, mode=mode))
        if logx:
            self.update_xaxes(type='log')
        if logy:
            self.update_yaxes(type='log')

    def plot_surface2d(self, data2d: Tensor, log_value: bool = True, colorscale: Optional[str] = None) -> None:
        """Add a 2D tensor as a heatmap with equal x / y scaling.

        The data is transposed before plotting, so its first index runs along x.

        Args:
            data2d: values to plot.
            log_value: plot ``log10`` of the values (zeros / negatives shown as 0).
            colorscale: plotly colorscale name; ``None`` uses the default.
        """
        data2d = self.__ensure_cpu_tensor(data2d)
        colorbar_title = None
        if log_value:
            data2d = self.__safe_log10(data2d)
            colorbar_title = 'log value'
        self.add_trace(go.Heatmap(
            z=data2d.T,
            colorscale=colorscale,
            colorbar={'title': colorbar_title},
        ))
        self.update_xaxes(
            scaleanchor='y',
            scaleratio=1,
            constrain='domain',
        )

    @staticmethod
    def __surfacecolor(coord: Tensor, value: Tensor|float|int|None, log_value: bool = True) -> Tensor:
        """Build the ``surfacecolor`` array for a surface trace.

        A scalar (or ``None``, treated as 1) fills an array shaped like ``coord``.
        A tensor is used as-is, or as ``log10`` with zeros / negatives mapped to 0.
        """
        value = 1 if value is None else value
        if isinstance(value, (float, int)):
            return torch.full_like(coord, value)
        return Figure.__safe_log10(value) if log_value else value

    def plot_surface3d(self, x: Tensor, y: Tensor, z: Tensor, value: Optional[Tensor|float|int] = None, log_value: bool = True, colorscale: Optional[str] = None) -> None:
        """Add a 3D surface colored by ``value`` and use equal axis aspect.

        Args:
            x, y, z: surface coordinates.
            value: per-point color values, a constant, or ``None`` (uniform color).
            log_value: color by ``log10(value)`` when ``value`` is not a constant.
            colorscale: plotly colorscale name for the shared color axis.
        """
        x, y, z, value = self.__ensure_cpu_tensor(x, y, z, value)
        surfacecolor = self.__surfacecolor(x, value, log_value)
        self.add_trace(go.Surface(
            x=x, y=y, z=z, surfacecolor=surfacecolor, coloraxis='coloraxis'
        ))
        # Same condition under which __surfacecolor actually applies the log.
        is_log_scaled = log_value and value is not None and not isinstance(value, (float, int))
        self.__update_coloraxis(colorscale, 'log value' if is_log_scaled else None)
        self.update_layout(scene_aspectmode='data')  # equal aspect; same as fig.update_scenes(aspectmode='data')

    def plot_volume3d(self, x: Tensor, y: Tensor, z: Tensor, value: Tensor, log_value: bool = False, opacity: float = 0.1, surface_count: int = 21, colorscale: Optional[str] = None) -> None:
        """Add a volume rendering of ``value`` sampled at points (x, y, z).

        Args:
            x, y, z: sample coordinates (flattened before plotting).
            value: sample values (flattened before plotting).
            log_value: render ``log10(value)`` with zeros / negatives mapped to 0.
            opacity: trace opacity.
            surface_count: number of iso-surfaces.
            colorscale: plotly colorscale name for the shared color axis.
        """
        x, y, z, value = self.__ensure_cpu_tensor(x, y, z, value)
        colorbar_title = None
        if log_value:
            value = self.__safe_log10(value)
            colorbar_title = 'log value'
        self.add_trace(go.Volume(
            x=x.flatten(),
            y=y.flatten(),
            z=z.flatten(),
            value=value.flatten(),
            opacity=opacity,
            surface_count=surface_count,
            coloraxis='coloraxis',
        ))
        self.update_layout(scene_aspectmode='data')  # equal aspect
        # The colorbar of a trace bound to a shared coloraxis is ignored, so the
        # title goes on the layout color axis.
        self.__update_coloraxis(colorscale, colorbar_title)

    def plot_voxel3d(self, x: Tensor, y: Tensor, z: Tensor, spacing: float, name: Optional[str] = None, showlegend: bool = True) -> None:
        """Add one cube of edge length ``spacing`` centered at each point (x, y, z)."""
        x, y, z = self.__ensure_cpu_tensor(x, y, z)
        self.add_trace(Voxel(
            xc=x,
            yc=y,
            zc=z,
            spacing=spacing,
            name=name,
            showlegend=showlegend,
        ))
        self.update_layout(scene_aspectmode='data')  # equal aspect

    def plot_detector(self, x: Tensor, y: Tensor, z: Tensor, value: Optional[Tensor] = None, log_value: bool = True, colorscale: Optional[str] = None) -> None:
        """Plot a detector surface together with the origin, direct beam and light edges.

        Args:
            x, y, z: detector pixel coordinates (at least 2-D; the four corners are
                taken at indices [0, 0], [0, -1], [-1, -1], [-1, 0]).
            value: per-pixel values used as surface color, or ``None``.
            log_value: color by ``log10(value)``.
            colorscale: plotly colorscale name.
        """
        x, y, z, value = self.__ensure_cpu_tensor(x, y, z, value)
        self.plot_surface3d(x, y, z, value=value, log_value=log_value, colorscale=colorscale)

        # origin
        self.add_scatter3d(x=[0], y=[0], z=[0], mode='markers', showlegend=False)

        # direct beam along the y axis, from the origin to the largest detector y
        self.add_scatter3d(x=[0, 0], y=[0, y.max().item()], z=[0, 0], mode='lines', showlegend=False)

        # light edges: translucent hull of the origin and the four detector corners
        corners = ((0, 0), (0, -1), (-1, -1), (-1, 0))
        ex, ey, ez = (
            torch.stack([c.new_zeros(())] + [c[idx] for idx in corners])
            for c in (x, y, z)
        )
        self.add_mesh3d(
            x=ex,
            y=ey,
            z=ez,
            alphahull=0,
            color='gray',
            opacity=0.1,
        )
        self.update_layout(scene_aspectmode='data')