Coverage for src\model2sas\plot.py: 67%
92 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:26 +0800
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:26 +0800
1"""All kinds of plot, from 1d to 3d.
2"""
4from pathlib import Path
5from typing import Literal, Optional, Sequence
7import torch
8from torch import Tensor
10import plotly.graph_objects as go
class Voxel(go.Mesh3d):
    """Mesh3d trace drawing one axis-aligned cube per given center point.

    Any extra keyword arguments (``name``, ``showlegend``, ``color``, ...) are
    forwarded unchanged to ``go.Mesh3d``.
    """

    def __init__(self, xc=None, yc=None, zc=None, spacing=None, **kwargs):
        x, y, z, i, j, k = self.gen_vertices_triangles(xc, yc, zc, spacing)
        super().__init__(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs)

    def gen_vertices_triangles(self, xc, yc, zc, spacing) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Generate vertices and triangles for a cube mesh plot.

        For each point (xc, yc, zc), generate a cubic box with edge length ``spacing``
        centered on that point.

        Args:
            xc: x coordinates of center points (tensor or array-like, any shape)
            yc: y coordinates of center points (same number of elements as xc)
            zc: z coordinates of center points (same number of elements as xc)
            spacing: spacing of the mesh grid, used as the cube edge length

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: ``(xv, yv, zv, i, j, k)``.
            ``xv, yv, zv`` are the vertex coordinates, 8 per point, so 8*N in total.
            Vertices 8n..8n+7 belong to the n-th point in flattened order.
            ``i, j, k`` are vertex indices of the triangles, 12 per point
            (2 per cube face), so 12*N in total.

        Raises:
            TypeError: if any of xc, yc, zc or spacing is None.
        """
        if xc is None or yc is None or zc is None or spacing is None:
            raise TypeError('xc, yc, zc and spacing are all required to build voxels')
        # Flatten first: the triangle offsets below assume vertices are laid out
        # point-major (8 consecutive vertices per point). Stacking N-d inputs along
        # dim=1 without flattening would interleave vertices of different points.
        xc = torch.as_tensor(xc).flatten()
        yc = torch.as_tensor(yc).flatten()
        zc = torch.as_tensor(zc).flatten()

        s = spacing
        # Cube corners 0-3 are the bottom face (z - s/2) and 4-7 the top face (z + s/2),
        # both traversed in the same order in the xy plane.
        xv = torch.stack((xc-s/2, xc+s/2, xc+s/2, xc-s/2, xc-s/2, xc+s/2, xc+s/2, xc-s/2), dim=1).flatten()
        yv = torch.stack((yc-s/2, yc-s/2, yc+s/2, yc+s/2, yc-s/2, yc-s/2, yc+s/2, yc+s/2), dim=1).flatten()
        zv = torch.stack((zc-s/2, zc-s/2, zc-s/2, zc-s/2, zc+s/2, zc+s/2, zc+s/2, zc+s/2), dim=1).flatten()

        # 12 triangles of a single cube, two per face, indexing the 8 corners above.
        i0 = torch.tensor((0, 2, 0, 5, 1, 6, 2, 7, 3, 4, 4, 6))
        j0 = torch.tensor((1, 3, 1, 4, 2, 5, 3, 6, 0, 7, 5, 7))
        k0 = torch.tensor((2, 0, 5, 0, 6, 1, 7, 2, 4, 3, 6, 4))
        # Shift the single-cube template by 8 vertices for every subsequent point.
        offset = 8 * torch.arange(xc.numel())
        i = (i0[None, :] + offset[:, None]).flatten()
        j = (j0[None, :] + offset[:, None]).flatten()
        k = (k0[None, :] + offset[:, None]).flatten()
        return xv, yv, zv, i, j, k
class Figure(go.Figure):
    """``go.Figure`` with convenience methods for plotting (torch) data.

    Each ``plot_*`` method moves tensor inputs to the CPU, adds one or more
    traces to this figure and adjusts the layout where needed (log axes,
    equal aspect ratio, shared color axis).
    """

    def __init__(self, data=None, layout=None, frames=None, skip_invalid=False, **kwargs):
        super().__init__(data, layout, frames, skip_invalid, **kwargs)

    def write_html(self, filename: str | Path, *args, **kwargs) -> None:
        """Write the figure to an HTML file, always encoded as UTF-8.

        The HTML text is produced by ``self.to_html(*args, **kwargs)``, so options
        such as ``include_plotlyjs``, ``full_html`` or ``config`` are honored.
        Pass them by keyword: the positional order of ``to_html`` differs from
        plotly's ``write_html`` after ``animation_opts``. Unlike plotly's own
        ``write_html``, no ``plotly.min.js`` bundle is written next to the file
        for ``include_plotlyjs='directory'``.

        Args:
            filename: path of the HTML file to write.
            auto_open (keyword): if true, open the written file in the default web browser.
        """
        # auto_open is a write_html-only option; to_html would reject it.
        auto_open = kwargs.pop('auto_open', False)
        path = Path(filename)
        path.write_text(self.to_html(*args, **kwargs), encoding='utf-8')
        if auto_open:
            import webbrowser
            webbrowser.open(path.resolve().as_uri())

    def set_title_text(self, text: str):
        """Set the figure title text."""
        self.update_layout(title_text=text)

    def set_template(self, plotly_template: str):
        """Set the plotly layout template by name (e.g. 'plotly_white')."""
        self.update_layout(template=plotly_template)

    @staticmethod
    def __ensure_cpu_tensor(*t):
        """Return the inputs with every Tensor detached and moved to the CPU.

        Non-tensor arguments (including None) pass through unchanged. With a single
        argument the converted value itself is returned, otherwise a tuple.
        """
        # detach() so tensors that track gradients can still be converted by plotly.
        converted = tuple(ti.detach().cpu() if isinstance(ti, Tensor) else ti for ti in t)
        return converted[0] if len(converted) == 1 else converted

    def plot_curve1d(self, x: Sequence|Tensor, y: Sequence|Tensor, name: Optional[str] = None, mode: Optional[Literal['lines', 'markers', 'lines+markers']] = None, logx: bool = True, logy: bool = True) -> None:
        """Add a scatter trace of y versus x.

        Args:
            x, y: data of the curve.
            name: trace name shown in the legend.
            mode: plotly scatter mode; None leaves plotly's default.
            logx, logy: if true, switch all x (resp. y) axes of the figure to log type.
                False does not revert an axis that is already logarithmic.
        """
        x, y = self.__ensure_cpu_tensor(x, y)
        self.add_trace(go.Scatter(x=x, y=y, name=name, mode=mode))
        if logx:
            self.update_xaxes(type='log')
        if logy:
            self.update_yaxes(type='log')

    def plot_surface2d(self, data2d: Tensor, log_value: bool = True, colorscale: Optional[str] = None) -> None:
        """Add a heatmap of a 2-D tensor with equally scaled axes.

        ``data2d`` is transposed before being passed as the heatmap ``z``, so its
        first index runs along the x axis.

        Args:
            data2d: 2-D data to show.
            log_value: if true, plot log10 of the data. Zeros (-inf) and negatives
                (nan) are mapped to 0.
            colorscale: plotly colorscale name; None uses the default.
        """
        data2d = self.__ensure_cpu_tensor(data2d)
        if log_value:
            # log10(0) = -inf and log10(<0) = nan; map both to 0 so they can be plotted.
            data2d = torch.log10(data2d).nan_to_num(nan=0., neginf=0.)
            colorbar_title = 'log value'
        else:
            colorbar_title = None
        self.add_trace(go.Heatmap(
            z=data2d.T,
            colorscale=colorscale,
            colorbar={'title': colorbar_title}
        ))
        self.update_xaxes(
            scaleanchor='y',
            scaleratio=1,
            constrain='domain'
        )

    @staticmethod
    def __surfacecolor(coord: Tensor, value: Tensor|float|int|None, log_value: bool = True) -> Tensor:
        """Build the ``surfacecolor`` array for a 3-D surface.

        None is treated as the constant 1. A number is broadcast to a tensor shaped
        like ``coord`` and is never log-transformed. A tensor is returned as-is, or
        as log10 (with nan/-inf mapped to 0) when ``log_value`` is true.
        """
        value = 1 if value is None else value
        if isinstance(value, (float, int)):
            surfacecolor = torch.full_like(coord, value)
        else:
            if log_value:
                surfacecolor = torch.log10(value).nan_to_num(nan=0., neginf=0.)  # incase 0 in data, cause log(0) output
            else:
                surfacecolor = value
        return surfacecolor

    def plot_surface3d(self, x: Tensor, y: Tensor, z: Tensor, value: Optional[Tensor|float|int] = None, log_value: bool = True, colorscale: Optional[str] = None) -> None:
        """Add a 3-D surface colored by ``value`` on the shared color axis.

        Args:
            x, y, z: surface coordinates as grids.
            value: color value per grid point, a constant, or None (constant 1).
            log_value: if true and ``value`` is a tensor, color by log10(value).
            colorscale: colorscale of the shared ``coloraxis``; None uses the default.
        """
        x, y, z, value = self.__ensure_cpu_tensor(x, y, z, value)
        surfacecolor = self.__surfacecolor(x, value, log_value)
        self.add_trace(go.Surface(
            x=x, y=y, z=z, surfacecolor=surfacecolor, coloraxis='coloraxis'
        ))
        self.update_layout(coloraxis={'colorscale': colorscale})
        self.update_layout(scene_aspectmode='data')  # equal aspect ratio

    def plot_volume3d(self, x: Tensor, y: Tensor, z: Tensor, value: Tensor, log_value: bool = False, opacity: float = 0.1, surface_count: int = 21, colorscale: Optional[str] = None) -> None:
        """Add a volume rendering of ``value`` sampled at points (x, y, z).

        All inputs are flattened, so any matching shapes are accepted.

        Args:
            x, y, z: coordinates of the sample points.
            value: value at each sample point.
            log_value: if true, render log10(value), with nan/-inf mapped to 0.
            opacity: opacity of the iso-surfaces.
            surface_count: number of iso-surfaces.
            colorscale: colorscale of the shared ``coloraxis``; None uses the default.
        """
        x, y, z, value = self.__ensure_cpu_tensor(x, y, z, value)
        if log_value:
            value = torch.log10(value).nan_to_num(nan=0., neginf=0.)  # incase 0 in data, cause log(0) output
            colorbar_title = 'log value'
        else:
            colorbar_title = None
        self.add_trace(go.Volume(
            x=x.flatten(),
            y=y.flatten(),
            z=z.flatten(),
            value=value.flatten(),
            opacity=opacity,
            surface_count=surface_count,
            coloraxis='coloraxis',
        ))
        self.update_layout(scene_aspectmode='data')  # equal aspect ratio
        # The trace is bound to the shared color axis, so plotly draws the colorbar
        # of layout.coloraxis and ignores a trace-level one; set the title there.
        self.update_layout(coloraxis={'colorscale': colorscale, 'colorbar': {'title': colorbar_title}})

    def plot_voxel3d(self, x: Tensor, y: Tensor, z: Tensor, spacing: float, name: Optional[str] = None, showlegend: bool = True) -> None:
        """Add one cube of edge length ``spacing`` centered at each point (x, y, z)."""
        x, y, z = self.__ensure_cpu_tensor(x, y, z)
        self.add_trace(Voxel(
            xc=x,
            yc=y,
            zc=z,
            spacing=spacing,
            name=name,
            showlegend=showlegend,
        ))
        self.update_layout(scene_aspectmode='data')  # equal aspect ratio

    def plot_detector(self, x: Tensor, y: Tensor, z: Tensor, value: Optional[Tensor] = None, log_value: bool = True, colorscale: Optional[str] = None) -> None:
        """Plot a detector surface together with the origin, beam and light cone.

        Draws the surface via ``plot_surface3d``, a marker at the origin, a line
        from the origin to (0, max(y), 0) for the direct beam along y, and a
        translucent gray convex hull spanning the origin and the four corner
        points of the detector grid.

        Args:
            x, y, z: detector pixel coordinates as 2-D grids.
            value: per-pixel value used for coloring, or None.
            log_value: if true, color by log10(value).
            colorscale: colorscale of the shared ``coloraxis``; None uses the default.
        """
        x, y, z, value = self.__ensure_cpu_tensor(x, y, z, value)
        self.plot_surface3d(x, y, z, value=value, log_value=log_value, colorscale=colorscale)

        # plot origin
        self.add_scatter3d(x=[0,], y=[0,], z=[0,], mode='markers', showlegend=False)

        # plot direct beam in y axis direction
        self.add_scatter3d(x=[0,0], y=[0,y.max().item()], z=[0,0], mode='lines', showlegend=False)

        # plot light edges: origin plus the four grid corners, as plain floats so
        # that tensors of differing dtypes need no stacking or promotion.
        corners = ((0, 0), (0, -1), (-1, -1), (-1, 0))
        ex = [0.] + [float(x[r, c]) for r, c in corners]
        ey = [0.] + [float(y[r, c]) for r, c in corners]
        ez = [0.] + [float(z[r, c]) for r, c in corners]
        self.add_mesh3d(
            x=ex,
            y=ey,
            z=ez,
            alphahull=0,  # 0 = convex hull of the given points
            color='gray',
            opacity=0.1,
        )
        self.update_layout(scene_aspectmode='data')