Coverage for src/duelboard/visualization.py: 25%

68 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 19:18 +0900

1"""Visualization utilities for Elo ratings and battle data.""" 

2 

3from __future__ import annotations 

4 

5from typing import TYPE_CHECKING 

6 

7import pandas as pd 

8 

9if TYPE_CHECKING: 

10 from plotly.graph_objects import Figure 

11 

12 from .types import RatingsDict 

13 

14try: 

15 import plotly.express as px 

16 import plotly.graph_objects as go # noqa: F401 

17 PLOTLY_AVAILABLE = True 

18except ImportError: 

19 PLOTLY_AVAILABLE = False 

20 

21 

def _check_plotly() -> None:
    """Ensure the optional plotly dependency was importable.

    Every public plotting function calls this first, so a missing plotly
    installation surfaces as a clear ImportError with install instructions
    instead of a NameError on ``px``.

    Raises:
        ImportError: If ``plotly`` failed to import when this module loaded.
    """
    if PLOTLY_AVAILABLE:
        return
    raise ImportError(
        "plotly is required for visualization. "
        "Install it with: pip install duelboard[visualization]"
    )

30 

31 

def plot_leaderboard(
    ratings: RatingsDict,
    title: str = "Elo Ratings Leaderboard",
    *,
    show_confidence_intervals: bool = True,
    height: int = 600,
    **kwargs: object,
) -> Figure:
    """Plot Elo ratings leaderboard.

    Players are ordered by rating, highest first. Players with equal ratings
    keep the iteration order of ``ratings``.

    The plot type depends on the data:

    - Scatter plot with asymmetric error bars: used when
      ``show_confidence_intervals`` is true and at least one rating has a
      truthy ``confidence_interval``. Players without an interval get no
      error bar (their error columns are NaN).
    - Bar chart: used otherwise.

    An empty ``ratings`` mapping produces an empty bar chart rather than
    raising.

    Args:
        ratings: Dictionary of player ratings
        title: Plot title
        show_confidence_intervals: Whether to show confidence intervals if available
        height: Plot height
        **kwargs: Additional arguments passed to plotly

    Returns:
        Plotly figure object

    Raises:
        ImportError: If plotly is not installed.
    """
    _check_plotly()

    # Build one row per player. Interval columns are only added when an
    # interval exists and is requested.
    data = []
    for player, rating in ratings.items():
        row = {
            "player": player,
            "rating": rating.rating,
            "battles": rating.battles,
        }

        if rating.confidence_interval and show_confidence_intervals:
            lower = rating.confidence_interval[0]
            upper = rating.confidence_interval[1]
            row["lower"] = lower
            row["upper"] = upper
            # Plotly error bars are distances from the point, not absolute bounds.
            row["error_minus"] = rating.rating - lower
            row["error_plus"] = upper - rating.rating

        data.append(row)

    df = pd.DataFrame(data)
    if df.empty:
        # pd.DataFrame([]) has no columns; keep the base columns so that
        # sort_values("rating") and the plot below do not raise KeyError.
        df = pd.DataFrame(columns=["player", "rating", "battles"])
    # Stable sort keeps tied players in a deterministic (input) order.
    df = df.sort_values("rating", ascending=False, kind="stable")

    if "error_minus" in df.columns and show_confidence_intervals:
        # Plot with error bars
        fig = px.scatter(
            df,
            x="player",
            y="rating",
            error_y="error_plus",
            error_y_minus="error_minus",
            text="rating",
            title=title,
            height=height,
            **kwargs,
        )
        fig.update_traces(texttemplate="%{text:.0f}", textposition="top center")
    else:
        # Plot without error bars
        fig = px.bar(
            df,
            x="player",
            y="rating",
            text="rating",
            title=title,
            height=height,
            **kwargs,
        )
        fig.update_traces(texttemplate="%{text:.0f}", textposition="outside")

    fig.update_layout(
        xaxis_title="Player",
        yaxis_title="Elo Rating",
        showlegend=False,
    )

    return fig

107 

108 

def plot_win_rate_matrix(
    win_rate_matrix: pd.DataFrame,
    title: str = "Predicted Win Rate Matrix",
    height: int = 600,
    width: int = 600,
    **kwargs: object,
) -> Figure:
    """Render a predicted win-rate matrix as an annotated heatmap.

    Rows are treated as the focal player ("Player A") and columns as the
    opponent ("Player B"). Players are ranked by their row mean, strongest
    first, and the same ranking is applied to both axes. Each cell is
    annotated with two decimals.

    Args:
        win_rate_matrix: Win rate matrix from WinRatePredictor
        title: Plot title
        height: Plot height
        width: Plot width
        **kwargs: Additional arguments passed to plotly

    Returns:
        Plotly figure object
    """
    _check_plotly()

    # Mean across each row = average predicted win rate of the focal player.
    row_means = win_rate_matrix.mean(axis=1)
    ranking = row_means.sort_values(ascending=False).index
    ranked = win_rate_matrix.loc[ranking, ranking]

    figure = px.imshow(
        ranked,
        color_continuous_scale="RdBu",
        text_auto=".2f",
        title=title,
        height=height,
        width=width,
        **kwargs,
    )

    # Opponent labels go on top.
    # title_y=0.07 places the title near the bottom of the figure.
    figure.update_layout(
        xaxis_title="Player B: Opponent",
        yaxis_title="Player A: Focal Player",
        xaxis_side="top",
        title_y=0.07,
        title_x=0.5,
    )
    figure.update_traces(
        hovertemplate="Player A: %{y}<br>Player B: %{x}<br>Win Rate: %{z}<extra></extra>",
    )

    return figure

157 

158 

def plot_battle_count_matrix(
    battle_count_matrix: pd.DataFrame,
    title: str = "Battle Count Matrix",
    height: int = 600,
    width: int = 600,
    **kwargs: object,
) -> Figure:
    """Render a pairwise battle-count matrix as an annotated heatmap.

    Players are ordered by their column total (most battles first). The same
    order is used for rows and columns. Cells are annotated with their raw
    counts.

    Args:
        battle_count_matrix: Battle count matrix
        title: Plot title
        height: Plot height
        width: Plot width
        **kwargs: Additional arguments passed to plotly

    Returns:
        Plotly figure object
    """
    _check_plotly()

    # Column sums (axis=0) are the per-player totals used for ordering.
    totals = battle_count_matrix.sum(axis=0)
    order = totals.sort_values(ascending=False).index
    reordered = battle_count_matrix.loc[order, order]

    figure = px.imshow(
        reordered,
        text_auto=True,
        title=title,
        height=height,
        width=width,
        **kwargs,
    )

    # Mirror the win-rate heatmap layout: x labels on top, title near bottom.
    figure.update_layout(
        xaxis_title="Player B",
        yaxis_title="Player A",
        xaxis_side="top",
        title_y=0.07,
        title_x=0.5,
    )
    figure.update_traces(
        hovertemplate="Player A: %{y}<br>Player B: %{x}<br>Count: %{z}<extra></extra>",
    )

    return figure

206 

207 

def plot_bootstrap_distributions(
    bootstrap_results: pd.DataFrame,
    title: str = "Bootstrap Rating Distributions",
    height: int = 400,
    **kwargs: object,
) -> Figure:
    """Draw one violin per column of ``bootstrap_results``.

    The frame is reshaped to long form: each column name becomes a "player"
    value and each cell becomes a "rating" value. Presumably each column is
    a player and each row one bootstrap sample (verify against the caller).

    Args:
        bootstrap_results: DataFrame with bootstrap results
        title: Plot title
        height: Plot height
        **kwargs: Additional arguments passed to plotly

    Returns:
        Plotly figure object
    """
    _check_plotly()

    long_form = bootstrap_results.melt(var_name="player", value_name="rating")

    figure = px.violin(
        long_form,
        x="player",
        y="rating",
        title=title,
        height=height,
        **kwargs,
    )
    figure.update_layout(
        xaxis_title="Player",
        yaxis_title="Rating",
        showlegend=False,
    )

    return figure

246 

247 

def plot_battle_outcome_distribution(
    battles: pd.DataFrame,
    title: str = "Battle Outcome Distribution",
    height: int = 400,
    **kwargs: object,
) -> Figure:
    """Bar chart of how often each value of the ``winner`` column occurs.

    Bars appear in ``value_counts`` order (most frequent first) and are
    labelled with their counts.

    Args:
        battles: DataFrame with battle data; must contain a ``winner`` column
        title: Plot title
        height: Plot height
        **kwargs: Additional arguments passed to plotly

    Returns:
        Plotly figure object
    """
    _check_plotly()

    tally = battles["winner"].value_counts()
    outcomes = tally.index
    counts = tally.values

    figure = px.bar(
        x=outcomes,
        y=counts,
        title=title,
        text_auto=True,
        height=height,
        **kwargs,
    )
    figure.update_layout(
        xaxis_title="Battle Outcome",
        yaxis_title="Count",
        showlegend=False,
    )

    return figure

285 

286 

def plot_player_battle_frequency(
    battles: pd.DataFrame,
    title: str = "Player Battle Frequency",
    top_k: int = 20,
    height: int = 500,
    **kwargs: object,
) -> Figure:
    """Bar chart of how many battles each player took part in.

    A player's count is the number of rows in which it appears in either
    the ``player_a`` or ``player_b`` column. Players are shown most active
    first.

    Args:
        battles: DataFrame with battle data; must contain ``player_a`` and
            ``player_b`` columns
        title: Plot title
        top_k: Number of top players to show; a falsy value (e.g. 0) shows
            every player
        height: Plot height
        **kwargs: Additional arguments passed to plotly

    Returns:
        Plotly figure object
    """
    _check_plotly()

    # Stack both participant columns so each appearance counts once.
    appearances = pd.concat([battles["player_a"], battles["player_b"]])
    counts = appearances.value_counts()
    counts = counts.head(top_k) if top_k else counts

    figure = px.bar(
        x=counts.index,
        y=counts.values,
        title=title,
        text_auto=True,
        height=height,
        **kwargs,
    )
    figure.update_layout(
        xaxis_title="Player",
        yaxis_title="Battle Count",
        showlegend=False,
    )

    return figure