Coverage for MPT/sankey_gap.py: 72%

162 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-08 18:27 +0200

1# -*- coding: utf-8 -*- 

2r""" 

3Produces simple Sankey Diagrams with matplotlib. 

4@author: Anneya Golob & marcomanz & pierre-sassoulas & jorwoods & vgalisson 

5 .-. 

6 .--.( ).--. 

7 <-. .-.-.(.-> )_ .--. 

8 `-`( )-' `) ) 

9 (o o ) `)`-' 

10 ( ) ,) 

11 ( () ) ) 

12 `---"\ , , ,/` 

13 `--' `--' `--' 

14 | | | | 

15 | | | | 

16 ' | ' | 

17""" 

18 

19# fmt: off 

20import warnings 

21import logging 

22from collections import defaultdict 

23 

24import matplotlib.pyplot as plt 

25import matplotlib.patheffects as path_effects 

26import numpy as np 

27import pandas as pd 

28import seaborn as sns 

29# fmt: on 

30 

# Module-level logger, named after this module via __name__.
LOGGER = logging.getLogger(name=__name__)

32 

33 

class PySankeyException(Exception):
    """Root of the exceptions raised by the Sankey plotting helpers."""


class NullsInFrame(PySankeyException):
    """Raised by sankey() when the left or right labels contain null values."""


class LabelMismatch(PySankeyException):
    """Raised when explicitly supplied labels differ from the labels in the data."""

44 

45 

def check_data_matches_labels(labels, data, side):
    """Check whether or not data matches labels.

    Raise a LabelMismatch Exception if not.

    :param labels: explicit label ordering supplied by the caller (list,
        tuple, numpy array, pandas Series, set, ...). An empty collection
        disables the check.
    :param data: the values actually present on that side (pandas Series or
        any other iterable).
    :param side: "left" or "right"; only used in the error message.
    :raises LabelMismatch: if the distinct labels differ from the distinct
        data values (order and multiplicity are ignored).
    """
    if len(labels) == 0:
        return

    # Compare as sets so any iterable works. Previously only lists were
    # converted, so a tuple or numpy array was compared directly against a set.
    if isinstance(data, pd.Series):
        data = set(data.unique().tolist())
    else:
        data = set(data)
    labels = set(labels)

    if labels != data:
        max_shown = 20  # only list values while the message stays readable
        msg = "\n"
        if len(labels) <= max_shown:
            # str() so numeric/mixed labels do not turn this into a TypeError.
            msg += "Labels: " + ",".join(sorted(map(str, labels))) + "\n"
        if len(data) <= max_shown:
            msg += "Data: " + ",".join(sorted(map(str, data)))
        raise LabelMismatch(
            "{0} labels and data do not match.{1}".format(side, msg)
        )

66 

67 

def sankey(
    left,
    right,
    leftWeight=None,
    rightWeight=None,
    colorDict=None,
    leftLabels=None,
    rightLabels=None,
    aspect=4,
    rightColor=False,
    fontsize="medium",
    figureName=None,
    closePlot=False,
    figSize=None,
    ax=None,
):
    """
    Make Sankey Diagram showing flow from left-->right

    Inputs:
        left = NumPy array of object labels on the left of the diagram
        right = NumPy array of corresponding labels on the right of the diagram
            len(right) == len(left)
        leftWeight = NumPy array of weights for each strip starting from the
            left of the diagram, if not specified 1 is assigned
        rightWeight = NumPy array of weights for each strip starting from the
            right of the diagram, if not specified the corresponding leftWeight
            is assigned
        colorDict = Dictionary of colors to use for each label
            {'label':'color'}; if not given an "hls" seaborn palette is built
            over all labels. When given it must contain every label used to
            colour a strip.
        leftLabels = order of the left labels in the diagram
        rightLabels = order of the right labels in the diagram
        aspect = vertical extent of the diagram in units of horizontal extent
        rightColor = If true, each strip in the diagram will be colored
            according to its right label (otherwise its left label)
        fontsize = font size passed to the label text
        figureName = deprecated; if given the figure is saved to
            '<figureName>.png'
        closePlot = deprecated; if true the current figure is closed at the end
        figSize = deprecated; tuple setting the width and height of the
            current figure
        ax = optional, matplotlib axes to plot on, otherwise uses current axes.
        Pandas Series inputs are aligned by position (their index is dropped).
    Output:
        ax : matplotlib Axes
    Raises:
        NullsInFrame if `left` or `right` contains null values.
        LabelMismatch if `leftLabels`/`rightLabels` do not match the data.
    """
    _warn_deprecated_arguments(figureName, closePlot, figSize)

    if ax is None:
        ax = plt.gca()

    # Check weights: one unit per row by default; right mirrors left.
    if leftWeight is None or len(leftWeight) == 0:
        leftWeight = np.ones(len(left))
    if rightWeight is None or len(rightWeight) == 0:
        rightWeight = leftWeight
    if leftLabels is None:
        leftLabels = []
    if rightLabels is None:
        rightLabels = []

    # Create Dataframe. Every Series (including the weights) is re-indexed
    # from 0 so it aligns positionally with index=range(len(left)); otherwise
    # a weight Series with a non-default index would be reindexed to NaN.
    left, right, leftWeight, rightWeight = (
        values.reset_index(drop=True) if isinstance(values, pd.Series) else values
        for values in (left, right, leftWeight, rightWeight)
    )
    dataFrame = pd.DataFrame(
        {
            "left": left,
            "right": right,
            "leftWeight": leftWeight,
            "rightWeight": rightWeight,
        },
        index=range(len(left)),
    )

    if dataFrame.left.isnull().any() or dataFrame.right.isnull().any():
        raise NullsInFrame("Sankey graph does not support null values.")

    # Identify all labels that appear 'left' or 'right'
    allLabels = pd.Series(
        np.r_[dataFrame.left.unique(), dataFrame.right.unique()]
    ).unique()
    LOGGER.debug("Labels to handle : %s", allLabels)

    # Identify left labels
    if len(leftLabels) == 0:
        leftLabels = pd.Series(dataFrame.left.unique()).unique()
    else:
        check_data_matches_labels(leftLabels, dataFrame["left"], "left")

    # Identify right labels
    if len(rightLabels) == 0:
        rightLabels = pd.Series(dataFrame.right.unique()).unique()
    else:
        check_data_matches_labels(rightLabels, dataFrame["right"], "right")

    # If no colorDict given, make one
    if colorDict is None:
        colorPalette = sns.color_palette("hls", len(allLabels))
        colorDict = {label: colorPalette[i] for i, label in enumerate(allLabels)}
    LOGGER.debug("The colordict value are : %s", colorDict)

    # Determine widths of individual strips, and remember which
    # (left, right) pairs occur so the plotting loop need not filter again.
    ns_l = {}
    ns_r = {}
    hasStrip = {}
    for leftLabel in leftLabels:
        isLeft = dataFrame.left == leftLabel
        ns_l[leftLabel] = {}
        ns_r[leftLabel] = {}
        for rightLabel in rightLabels:
            pair = dataFrame[isLeft & (dataFrame.right == rightLabel)]
            ns_l[leftLabel][rightLabel] = pair.leftWeight.sum()
            ns_r[leftLabel][rightLabel] = pair.rightWeight.sum()
            hasStrip[leftLabel, rightLabel] = len(pair) > 0

    # Determine positions of label patches and total widths on each side.
    # NOTE: the left topEdge is discarded; the diagram extent uses the right.
    leftWidths, topEdge = _get_positions_and_total_widths(dataFrame, leftLabels, "left")
    rightWidths, topEdge = _get_positions_and_total_widths(
        dataFrame, rightLabels, "right"
    )

    # Horizontal extent of the diagram (topEdge is the vertical extent).
    xMax = topEdge / aspect

    # Print labels beside each side's bars (must run before the strip loop
    # below, which advances the "bottom" entries).
    _draw_side_labels(ax, leftLabels, leftWidths, "left", -0.05 * xMax, "right", fontsize)
    _draw_side_labels(ax, rightLabels, rightWidths, "right", 1.05 * xMax, "left", fontsize)

    ymin, ymax = None, None
    # Plot strips
    for leftLabel in leftLabels:
        for rightLabel in rightLabels:
            if not hasStrip[leftLabel, rightLabel]:
                continue
            labelColor = rightLabel if rightColor else leftLabel
            leftBottom = leftWidths[leftLabel]["bottom"]
            rightBottom = rightWidths[rightLabel]["bottom"]
            ys_d = _smooth_strip_edge(leftBottom, rightBottom)
            ys_u = _smooth_strip_edge(
                leftBottom + ns_l[leftLabel][rightLabel],
                rightBottom + ns_r[leftLabel][rightLabel],
            )

            # Update bottom edges at each label so next strip starts at the right place
            leftWidths[leftLabel]["bottom"] += ns_l[leftLabel][rightLabel]
            rightWidths[rightLabel]["bottom"] += ns_r[leftLabel][rightLabel]
            ax.fill_between(
                np.linspace(-0.013, 1.013, len(ys_d)) * xMax,
                ys_d,
                ys_u,
                alpha=1.0,
                zorder=1,
                facecolor=colorDict[labelColor],
            )

            if ymin is None:
                ymin = ys_d.min()
            ymin = np.min([ys_d.min(), ys_u.min(), ymin])

            if ymax is None:
                ymax = ys_d.max()
            ymax = np.max([ys_d.max(), ys_u.max(), ymax])

    ax.set_ylim(ymin, ymax)
    ax.axis("off")

    if figSize is not None:
        plt.gcf().set_size_inches(figSize)

    if figureName is not None:
        fileName = "{}.png".format(figureName)
        plt.savefig(fileName, bbox_inches="tight", dpi=150)
        LOGGER.info("Sankey diagram generated in '%s'", fileName)
    if closePlot:
        plt.close()

    return ax


def _warn_deprecated_arguments(figureName, closePlot, figSize):
    """Emit a DeprecationWarning per deprecated sankey() argument, plus one log line."""
    deprecated = []
    for name, used in (
        ("figureName", figureName is not None),
        ("closePlot", closePlot is not False),
        ("figSize", figSize is not None),
    ):
        if used:
            msg = "use of {} in sankey() is deprecated".format(name)
            # stacklevel=3: this helper -> sankey() -> the user's call site.
            warnings.warn(msg, DeprecationWarning, stacklevel=3)
            deprecated.append("{} in sankey()".format(name))

    if deprecated:
        LOGGER.warning(
            " The following arguments are deprecated and should be removed: %s",
            ", ".join(deprecated),
        )


def _draw_side_labels(ax, labels, widths, side, x, ha, fontsize):
    """Write the labels of one side next to their bars, skipping crowded ones."""
    previous = None
    for i, label in enumerate(labels):
        # _draw_label() is True when this label's centre is too close to the
        # last label actually written, in which case it is skipped.
        if i != 0 and _draw_label(widths[label], widths[previous]):
            continue
        ax.text(
            x,
            widths[label]["bottom"] + 0.5 * widths[label][side],
            rf"\textbf{{{label}}}",
            {"ha": ha, "va": "center"},
            fontsize=fontsize,
            zorder=2,
        )
        previous = label


def _smooth_strip_edge(start, end):
    """Return the y-values of one strip edge running from `start` to `end`."""
    # Half the samples at the left height, half at the right height; two
    # passes of a 20-point moving average turn the step into a smooth curve.
    ys = np.array(50 * [start] + 50 * [end])
    kernel = 0.05 * np.ones(20)
    ys = np.convolve(ys, kernel, mode="valid")
    return np.convolve(ys, kernel, mode="valid")

320 

321 

322def _get_positions_and_total_widths(df, labels, side): 

323 """Determine positions of label patches and total widths""" 

324 # add gap 

325 gap = 50000 

326 # print(f'gap : {gap}') 

327 widths = defaultdict() 

328 for i, label in enumerate(labels): 

329 labelWidths = {} 

330 labelWidths[side] = df[df[side] == label][side + "Weight"].sum() 

331 if i == 0: 

332 labelWidths["bottom"] = 0 

333 labelWidths["top"] = labelWidths[side] 

334 else: 

335 bottomWidth = widths[labels[i - 1]]["top"] + gap 

336 weightedSum = 0.02 * df[side + "Weight"].sum() 

337 labelWidths["bottom"] = bottomWidth + weightedSum 

338 labelWidths["top"] = labelWidths["bottom"] + labelWidths[side] 

339 topEdge = labelWidths["top"] 

340 widths[label] = labelWidths 

341 LOGGER.debug("%s position of '%s' : %s", side, label, labelWidths) 

342 

343 return widths, topEdge 

344 

345 

346def _draw_label(widths1, widths2, minDistanceOfLabels=150000): 

347 return ( 

348 np.abs( 

349 (widths1["top"] - 0.5 * (widths1["top"] - widths1["bottom"])) 

350 - (widths2["top"] - 0.5 * (widths2["top"] - widths2["bottom"])) 

351 ) 

352 < minDistanceOfLabels 

353 )