Coverage for MPP/sankey_gap.py: 71%

161 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-11 14:59 +0200

1# -*- coding: utf-8 -*- 

2r""" 

3Produces simple Sankey Diagrams with matplotlib. 

4@author: Anneya Golob & marcomanz & pierre-sassoulas & jorwoods & vgalisson 

5 .-. 

6 .--.( ).--. 

7 <-. .-.-.(.-> )_ .--. 

8 `-`( )-' `) ) 

9 (o o ) `)`-' 

10 ( ) ,) 

11 ( () ) ) 

12 `---"\ , , ,/` 

13 `--' `--' `--' 

14 | | | | 

15 | | | | 

16 ' | ' | 

17""" 

18 

19# fmt: off 

20import warnings 

21import logging 

22from collections import defaultdict 

23 

24import matplotlib.pyplot as plt 

25import numpy as np 

26import pandas as pd 

27import seaborn as sns 

28# fmt: on 

29 

# Module-level logger shared by every function of this module.
LOGGER = logging.getLogger(name=__name__)

31 

32 

class PySankeyException(Exception):
    """Root of the exception hierarchy raised by the sankey helpers of this module."""

35 

36 

class NullsInFrame(PySankeyException):
    """Raised when the ``left`` or ``right`` input of ``sankey`` contains null values."""

39 

40 

class LabelMismatch(PySankeyException):
    """Raised when user-supplied label orderings do not match the labels found in the data."""

43 

44 

def check_data_matches_labels(labels, data, side, max_listed=20):
    """Check whether or not data matches labels.

    ``labels`` and ``data`` are each reduced to the set of distinct values they
    contain and compared; order and duplicates are ignored.  Any iterable is
    accepted (list, tuple, set, NumPy array, pandas Series).  Nothing is
    checked when ``labels`` is empty.

    Args:
        labels: label values requested by the caller.
        data: values actually present in the data.
        side: name of the side being checked (e.g. "left"/"right"), used in
            the error message.
        max_listed: labels (resp. data values) are listed in the error message
            only when there are at most this many of them.

    Raises:
        LabelMismatch: if the distinct values of ``labels`` and ``data`` differ.
    """
    if len(labels) == 0:
        return

    def _distinct(values):
        # Series.unique().tolist() yields plain Python scalars; any other
        # iterable (list, tuple, ndarray, set) is converted directly.
        if isinstance(values, pd.Series):
            return set(values.unique().tolist())
        return set(values)

    labels = _distinct(labels)
    data = _distinct(data)
    if labels == data:
        return

    # Append (not overwrite) to the leading newline so the details start on
    # their own line.  str() lets non-string labels (ints, ...) be joined, and
    # sorting makes the message deterministic despite set ordering.
    msg = "\n"
    if len(labels) <= max_listed:
        msg += "Labels: " + ",".join(sorted(map(str, labels))) + "\n"
    if len(data) <= max_listed:
        msg += "Data: " + ",".join(sorted(map(str, data)))
    raise LabelMismatch(
        "{0} labels and data do not match.{1}".format(side, msg)
    )

65 

66 

def sankey(
    left,
    right,
    leftWeight=None,
    rightWeight=None,
    colorDict=None,
    leftLabels=None,
    rightLabels=None,
    aspect=4,
    rightColor=False,
    fontsize="medium",
    figureName=None,
    closePlot=False,
    figSize=None,
    ax=None,
):
    """
    Make Sankey Diagram showing flow from left-->right

    Inputs:
        left = NumPy array (or pandas Series) of object labels on the left of
            the diagram
        right = NumPy array (or pandas Series) of corresponding labels on the
            right of the diagram; len(right) == len(left)
        leftWeight = weights for each strip starting from the left of the
            diagram; if not specified 1 is assigned to every strip
        rightWeight = weights for each strip arriving on the right of the
            diagram; if not specified the corresponding leftWeight is used
        colorDict = Dictionary of colors to use for each label
            {'label':'color'}; if None, a seaborn "hls" palette is generated
        leftLabels = order of the left labels in the diagram; must contain
            exactly the distinct values of ``left``
        rightLabels = order of the right labels in the diagram; must contain
            exactly the distinct values of ``right``
        aspect = vertical extent of the diagram in units of horizontal extent
        rightColor = If true, each strip is colored according to its right
            label; otherwise according to its left label
        fontsize = font size passed to ``ax.text`` for the labels
        figureName = deprecated; if given, the current figure is saved to
            "<figureName>.png"
        closePlot = deprecated; if true, ``plt.close()`` is called at the end
        figSize = deprecated; (width, height) in inches applied to the
            current figure
        ax = optional, matplotlib axes to plot on, otherwise uses current axes.

    Output:
        ax : matplotlib Axes

    Raises:
        NullsInFrame: if ``left`` or ``right`` contains null values.
        LabelMismatch: if ``leftLabels``/``rightLabels`` do not match the data.
    """
    deprecated = []
    for argName, isUsed in (
        ("figureName", figureName is not None),
        ("closePlot", closePlot is not False),
        ("figSize", figSize is not None),
    ):
        if isUsed:
            # stacklevel=2 attributes the warning to the caller of sankey().
            warnings.warn(
                "use of {} in sankey() is deprecated".format(argName),
                DeprecationWarning,
                stacklevel=2,
            )
            deprecated.append("{} in sankey()".format(argName))
    if deprecated:
        LOGGER.warning(
            " The following arguments are deprecated and should be removed: %s",
            ", ".join(deprecated),
        )

    if ax is None:
        ax = plt.gca()

    # Default weights: 1 per row on the left, right weights mirror the left.
    if leftWeight is None or len(leftWeight) == 0:
        leftWeight = np.ones(len(left))
    if rightWeight is None or len(rightWeight) == 0:
        rightWeight = leftWeight
    if leftLabels is None:
        leftLabels = []
    if rightLabels is None:
        rightLabels = []

    # Drop any existing index so every column is aligned by position when the
    # DataFrame is built.  Weights must be reset too: a weight Series keeping
    # a non-default index would otherwise be re-aligned on range(len(left))
    # and silently turn into NaN (dropped from every sum).
    if isinstance(left, pd.Series):
        left = left.reset_index(drop=True)
    if isinstance(right, pd.Series):
        right = right.reset_index(drop=True)
    if isinstance(leftWeight, pd.Series):
        leftWeight = leftWeight.reset_index(drop=True)
    if isinstance(rightWeight, pd.Series):
        rightWeight = rightWeight.reset_index(drop=True)
    dataFrame = pd.DataFrame(
        {
            "left": left,
            "right": right,
            "leftWeight": leftWeight,
            "rightWeight": rightWeight,
        },
        index=range(len(left)),
    )

    if (dataFrame.left.isnull() | dataFrame.right.isnull()).any():
        raise NullsInFrame("Sankey graph does not support null values.")

    # Identify all labels that appear 'left' or 'right'
    allLabels = pd.Series(
        np.r_[dataFrame.left.unique(), dataFrame.right.unique()]
    ).unique()
    LOGGER.debug("Labels to handle : %s", allLabels)

    if len(leftLabels) == 0:
        leftLabels = pd.Series(dataFrame.left.unique()).unique()
    else:
        check_data_matches_labels(leftLabels, dataFrame["left"], "left")

    if len(rightLabels) == 0:
        rightLabels = pd.Series(dataFrame.right.unique()).unique()
    else:
        check_data_matches_labels(rightLabels, dataFrame["right"], "right")

    if colorDict is None:
        colorDict = dict(zip(allLabels, sns.color_palette("hls", len(allLabels))))
    LOGGER.debug("The colordict value are : %s", colorDict)

    # Width of every strip that occurs in the data, keyed by
    # (leftLabel, rightLabel) -> (width on the left, width on the right).
    # Pairs with no rows are absent and therefore not drawn.
    strips = {}
    for leftLabel in leftLabels:
        isLeft = dataFrame.left == leftLabel
        for rightLabel in rightLabels:
            rows = dataFrame[isLeft & (dataFrame.right == rightLabel)]
            if len(rows) > 0:
                strips[(leftLabel, rightLabel)] = (
                    rows.leftWeight.sum(),
                    rows.rightWeight.sum(),
                )

    # Positions of the label patches on each side.
    leftWidths, _ = _get_positions_and_total_widths(dataFrame, leftLabels, "left")
    rightWidths, topEdge = _get_positions_and_total_widths(
        dataFrame, rightLabels, "right"
    )

    # Horizontal extent, derived from the right column's top edge.
    xMax = topEdge / aspect

    # Labels are drawn before the strips, while "bottom" still holds each
    # patch's initial position.
    _draw_side_labels(ax, leftLabels, leftWidths, "left", -0.05 * xMax, "right", fontsize)
    _draw_side_labels(ax, rightLabels, rightWidths, "right", 1.05 * xMax, "left", fontsize)

    ymin, ymax = None, None
    for leftLabel in leftLabels:
        for rightLabel in rightLabels:
            strip = strips.get((leftLabel, rightLabel))
            if strip is None:
                continue
            leftStrip, rightStrip = strip
            leftBottom = leftWidths[leftLabel]["bottom"]
            rightBottom = rightWidths[rightLabel]["bottom"]
            ys_d = _smooth_transition(leftBottom, rightBottom)
            ys_u = _smooth_transition(leftBottom + leftStrip, rightBottom + rightStrip)

            # Update bottom edges at each label so next strip starts at the right place
            leftWidths[leftLabel]["bottom"] += leftStrip
            rightWidths[rightLabel]["bottom"] += rightStrip

            labelColor = rightLabel if rightColor else leftLabel
            ax.fill_between(
                np.linspace(-0.013, 1.013, len(ys_d)) * xMax,
                ys_d,
                ys_u,
                alpha=1.0,
                zorder=1,
                facecolor=colorDict[labelColor],
            )

            stripMin = min(ys_d.min(), ys_u.min())
            stripMax = max(ys_d.max(), ys_u.max())
            ymin = stripMin if ymin is None else min(ymin, stripMin)
            ymax = stripMax if ymax is None else max(ymax, stripMax)

    ax.set_ylim(ymin, ymax)
    ax.axis("off")

    if figSize is not None:
        plt.gcf().set_size_inches(figSize)

    if figureName is not None:
        fileName = "{}.png".format(figureName)
        plt.savefig(fileName, bbox_inches="tight", dpi=150)
        LOGGER.info("Sankey diagram generated in '%s'", fileName)
    if closePlot:
        plt.close()

    return ax


def _draw_side_labels(ax, labels, widths, side, x, ha, fontsize):
    """Write one column of labels at abscissa ``x`` (helper of ``sankey``).

    Each label is vertically centred on its patch.  From the second label on,
    a label is skipped when ``_draw_label`` reports that its centre is too
    close to the last label actually written.
    """
    previous = None
    for i, label in enumerate(labels):
        if i != 0 and _draw_label(widths[label], widths[previous]):
            continue
        # NOTE(review): "\textbf" is only interpreted when matplotlib's
        # text.usetex is enabled; otherwise it is printed literally - confirm
        # the rcParams used by callers.
        ax.text(
            x,
            widths[label]["bottom"] + 0.5 * widths[label][side],
            rf"\textbf{{{label}}}",
            {"ha": ha, "va": "center"},
            fontsize=fontsize,
            zorder=2,
        )
        previous = label


def _smooth_transition(leftValue, rightValue):
    """Return the 62-point smoothed edge going from ``leftValue`` to ``rightValue``.

    A 100-point step (50 points at each value) is convolved twice with a
    20-point moving average (``mode="valid"``), giving an S-shaped curve.
    """
    ys = np.array(50 * [leftValue] + 50 * [rightValue])
    ys = np.convolve(ys, 0.05 * np.ones(20), mode="valid")
    return np.convolve(ys, 0.05 * np.ones(20), mode="valid")

319 

320 

321def _get_positions_and_total_widths(df, labels, side): 

322 """Determine positions of label patches and total widths""" 

323 # add gap 

324 gap = 50000 

325 # print(f'gap : {gap}') 

326 widths = defaultdict() 

327 for i, label in enumerate(labels): 

328 labelWidths = {} 

329 labelWidths[side] = df[df[side] == label][side + "Weight"].sum() 

330 if i == 0: 

331 labelWidths["bottom"] = 0 

332 labelWidths["top"] = labelWidths[side] 

333 else: 

334 bottomWidth = widths[labels[i - 1]]["top"] + gap 

335 weightedSum = 0.02 * df[side + "Weight"].sum() 

336 labelWidths["bottom"] = bottomWidth + weightedSum 

337 labelWidths["top"] = labelWidths["bottom"] + labelWidths[side] 

338 topEdge = labelWidths["top"] 

339 widths[label] = labelWidths 

340 LOGGER.debug("%s position of '%s' : %s", side, label, labelWidths) 

341 

342 return widths, topEdge 

343 

344 

345def _draw_label(widths1, widths2, minDistanceOfLabels=150000): 

346 return ( 

347 np.abs( 

348 (widths1["top"] - 0.5 * (widths1["top"] - widths1["bottom"])) 

349 - (widths2["top"] - 0.5 * (widths2["top"] - widths2["bottom"])) 

350 ) 

351 < minDistanceOfLabels 

352 )