Coverage for MPP/sankey_gap.py: 71%
161 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 14:59 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 14:59 +0200
1# -*- coding: utf-8 -*-
2r"""
3Produces simple Sankey Diagrams with matplotlib.
4@author: Anneya Golob & marcomanz & pierre-sassoulas & jorwoods & vgalisson
5 .-.
6 .--.( ).--.
7 <-. .-.-.(.-> )_ .--.
8 `-`( )-' `) )
9 (o o ) `)`-'
10 ( ) ,)
11 ( () ) )
12 `---"\ , , ,/`
13 `--' `--' `--'
14 | | | |
15 | | | |
16 ' | ' |
17"""
19# fmt: off
20import warnings
21import logging
22from collections import defaultdict
24import matplotlib.pyplot as plt
25import numpy as np
26import pandas as pd
27import seaborn as sns
28# fmt: on
# Module-level logger, named after this module so applications can tune it.
LOGGER = logging.getLogger(name=__name__)
class PySankeyException(Exception):
    """Base class for the errors raised by this Sankey module.

    Catch this to handle any PySankey-specific failure in one place.
    """
class NullsInFrame(PySankeyException):
    """Raised when the left or right label data contains null values."""
class LabelMismatch(PySankeyException):
    """Raised when a user-supplied label ordering does not match the data labels."""
def check_data_matches_labels(labels, data, side):
    """Check that the distinct values of ``data`` are exactly ``labels``.

    Parameters
    ----------
    labels : iterable
        Expected labels (list, tuple, set, NumPy array or pandas Series).
        An empty ``labels`` disables the check.
    data : iterable
        Observed values (list, NumPy array or pandas Series); order and
        duplicates are ignored.
    side : str
        Name of the side being checked (e.g. "left"), used in the message.

    Raises
    ------
    LabelMismatch
        If the set of labels differs from the set of data values.
    """
    if len(labels) == 0:
        return
    # Normalise both sides to sets whatever container they arrive in; a raw
    # ndarray compared to a set would produce an ambiguous boolean array.
    labelSet = set(labels)
    dataSet = set(data)
    if labelSet == dataSet:
        return
    msg = "\n"
    # Only list the values when there are few enough to be readable; str()
    # so non-string labels can be joined, sorted for a stable message.
    if len(labelSet) <= 20:
        msg += "Labels: " + ",".join(sorted(map(str, labelSet))) + "\n"
    if len(dataSet) <= 20:
        msg += "Data: " + ",".join(sorted(map(str, dataSet)))
    raise LabelMismatch(
        "{0} labels and data do not match.{1}".format(side, msg)
    )
def _reset_series_index(values):
    """Return ``values`` re-indexed 0..n-1 if it is a pandas Series, else as-is.

    ``pd.DataFrame`` aligns Series inputs on their index, so a Series whose
    index is not ``range(n)`` would otherwise land on the wrong rows (or NaN).
    """
    if isinstance(values, pd.Series):
        return values.reset_index(drop=True)
    return values


def _strip_edge(start, end):
    """Return one smoothed edge of a strip, going from ``start`` to ``end``.

    A step of 50 values at ``start`` followed by 50 at ``end`` is smoothed
    twice with a 20-point moving average ("valid" mode), giving 62 points.
    """
    kernel = 0.05 * np.ones(20)
    ys = np.array(50 * [start] + 50 * [end])
    ys = np.convolve(ys, kernel, mode="valid")
    return np.convolve(ys, kernel, mode="valid")


def _draw_side_labels(ax, labels, widths, side, x, ha, fontsize):
    """Write each label of one side at ``x``, vertically centred on its patch.

    A label is skipped when ``_draw_label`` reports it is too close to the
    last label actually drawn on the same side.
    """
    previous = None
    for i, label in enumerate(labels):
        if i != 0 and _draw_label(widths[label], widths[previous]):
            continue
        ax.text(
            x,
            widths[label]["bottom"] + 0.5 * widths[label][side],
            rf"\textbf{{{label}}}",
            {"ha": ha, "va": "center"},
            fontsize=fontsize,
            zorder=2,
        )
        previous = label


def sankey(
    left,
    right,
    leftWeight=None,
    rightWeight=None,
    colorDict=None,
    leftLabels=None,
    rightLabels=None,
    aspect=4,
    rightColor=False,
    fontsize="medium",
    figureName=None,
    closePlot=False,
    figSize=None,
    ax=None,
):
    """
    Make a Sankey diagram showing flow from left --> right.

    Inputs:
        left = NumPy array of object labels on the left of the diagram
        right = NumPy array of corresponding labels on the right of the diagram
            len(right) == len(left)
        leftWeight = NumPy array of weights for each strip starting from the
            left of the diagram, if not specified 1 is assigned
        rightWeight = NumPy array of weights for each strip starting from the
            right of the diagram, if not specified the corresponding leftWeight
            is assigned
        colorDict = Dictionary of colors to use for each label
            {'label':'color'}; an "hls" seaborn palette is used if omitted
        leftLabels = order of the left labels in the diagram
        rightLabels = order of the right labels in the diagram
        aspect = vertical extent of the diagram in units of horizontal extent
        rightColor = If true, each strip in the diagram is colored according
            to its right label; otherwise according to its left label
        fontsize = font size passed to the label texts
        figureName = deprecated; if given, the figure is saved as
            '<figureName>.png'
        closePlot = deprecated; if true, the figure is closed at the end
        figSize = deprecated; tuple setting the width and height of the
            figure holding ``ax``
        ax = optional, matplotlib axes to plot on, otherwise uses current axes.
    Pandas Series inputs are used positionally (their index is ignored).
    Output:
        ax : matplotlib Axes
    Raises:
        NullsInFrame if ``left`` or ``right`` contains null values.
        LabelMismatch if ``leftLabels``/``rightLabels`` do not match the data.
    """
    deprecatedUsed = [
        name
        for name, used in (
            ("figureName", figureName is not None),
            ("closePlot", closePlot is not False),
            ("figSize", figSize is not None),
        )
        if used
    ]
    for name in deprecatedUsed:
        # stacklevel=2 attributes the warning to the caller of sankey().
        warnings.warn(
            "use of {} in sankey() is deprecated".format(name),
            DeprecationWarning,
            stacklevel=2,
        )
    if deprecatedUsed:
        LOGGER.warning(
            " The following arguments are deprecated and should be removed: %s",
            ", ".join(deprecatedUsed),
        )

    if ax is None:
        ax = plt.gca()
    if leftLabels is None:
        leftLabels = []
    if rightLabels is None:
        rightLabels = []

    # Default weights: 1 per row on the left; the right mirrors the left.
    if leftWeight is None or len(leftWeight) == 0:
        leftWeight = np.ones(len(left))
    if rightWeight is None or len(rightWeight) == 0:
        rightWeight = leftWeight

    # All four columns are made positional so they line up row by row.
    dataFrame = pd.DataFrame(
        {
            "left": _reset_series_index(left),
            "right": _reset_series_index(right),
            "leftWeight": _reset_series_index(leftWeight),
            "rightWeight": _reset_series_index(rightWeight),
        },
        index=range(len(left)),
    )

    if (dataFrame.left.isnull() | dataFrame.right.isnull()).any():
        raise NullsInFrame("Sankey graph does not support null values.")

    # Identify all labels that appear 'left' or 'right'
    allLabels = pd.Series(
        np.r_[dataFrame.left.unique(), dataFrame.right.unique()]
    ).unique()
    LOGGER.debug("Labels to handle : %s", allLabels)

    if len(leftLabels) == 0:
        leftLabels = pd.Series(dataFrame.left.unique()).unique()
    else:
        check_data_matches_labels(leftLabels, dataFrame["left"], "left")

    if len(rightLabels) == 0:
        rightLabels = pd.Series(dataFrame.right.unique()).unique()
    else:
        check_data_matches_labels(rightLabels, dataFrame["right"], "right")

    if colorDict is None:
        colorPalette = sns.color_palette("hls", len(allLabels))
        colorDict = dict(zip(allLabels, colorPalette))
    LOGGER.debug("The colordict value are : %s", colorDict)

    # Width of each strip at both ends, and whether any row links the pair
    # (one row filter per pair, reused for all three quantities).
    ns_l, ns_r, hasFlow = {}, {}, {}
    for leftLabel in leftLabels:
        leftDict, rightDict = {}, {}
        isLeft = dataFrame.left == leftLabel
        for rightLabel in rightLabels:
            pair = dataFrame[isLeft & (dataFrame.right == rightLabel)]
            leftDict[rightLabel] = pair.leftWeight.sum()
            rightDict[rightLabel] = pair.rightWeight.sum()
            hasFlow[leftLabel, rightLabel] = len(pair) > 0
        ns_l[leftLabel] = leftDict
        ns_r[leftLabel] = rightDict

    # Positions of the label patches on each side
    leftWidths, _ = _get_positions_and_total_widths(dataFrame, leftLabels, "left")
    rightWidths, topEdge = _get_positions_and_total_widths(
        dataFrame, rightLabels, "right"
    )

    # Horizontal extent of the diagram, derived from the right side's height
    xMax = topEdge / aspect

    _draw_side_labels(
        ax, leftLabels, leftWidths, "left", -0.05 * xMax, "right", fontsize
    )
    _draw_side_labels(
        ax, rightLabels, rightWidths, "right", 1.05 * xMax, "left", fontsize
    )

    ymin, ymax = None, None
    for leftLabel in leftLabels:
        for rightLabel in rightLabels:
            if not hasFlow[leftLabel, rightLabel]:
                continue
            labelColor = rightLabel if rightColor else leftLabel
            leftBottom = leftWidths[leftLabel]["bottom"]
            rightBottom = rightWidths[rightLabel]["bottom"]
            ys_d = _strip_edge(leftBottom, rightBottom)
            ys_u = _strip_edge(
                leftBottom + ns_l[leftLabel][rightLabel],
                rightBottom + ns_r[leftLabel][rightLabel],
            )

            # Update bottom edges so the next strip of each label stacks on top
            leftWidths[leftLabel]["bottom"] += ns_l[leftLabel][rightLabel]
            rightWidths[rightLabel]["bottom"] += ns_r[leftLabel][rightLabel]
            ax.fill_between(
                np.linspace(-0.013, 1.013, len(ys_d)) * xMax,
                ys_d,
                ys_u,
                alpha=1.0,
                zorder=1,
                facecolor=colorDict[labelColor],
            )

            if ymin is None:
                ymin = ys_d.min()
            ymin = np.min([ys_d.min(), ys_u.min(), ymin])

            if ymax is None:
                ymax = ys_d.max()
            ymax = np.max([ys_d.max(), ys_u.max(), ymax])

    ax.set_ylim(ymin, ymax)
    ax.axis("off")

    # Act on the figure that owns ``ax``, which need not be the current one.
    figure = ax.figure
    if figSize is not None:
        figure.set_size_inches(figSize)

    if figureName is not None:
        fileName = "{}.png".format(figureName)
        figure.savefig(fileName, bbox_inches="tight", dpi=150)
        LOGGER.info("Sankey diagram generated in '%s'", fileName)
    if closePlot:
        plt.close(figure)

    return ax
321def _get_positions_and_total_widths(df, labels, side):
322 """Determine positions of label patches and total widths"""
323 # add gap
324 gap = 50000
325 # print(f'gap : {gap}')
326 widths = defaultdict()
327 for i, label in enumerate(labels):
328 labelWidths = {}
329 labelWidths[side] = df[df[side] == label][side + "Weight"].sum()
330 if i == 0:
331 labelWidths["bottom"] = 0
332 labelWidths["top"] = labelWidths[side]
333 else:
334 bottomWidth = widths[labels[i - 1]]["top"] + gap
335 weightedSum = 0.02 * df[side + "Weight"].sum()
336 labelWidths["bottom"] = bottomWidth + weightedSum
337 labelWidths["top"] = labelWidths["bottom"] + labelWidths[side]
338 topEdge = labelWidths["top"]
339 widths[label] = labelWidths
340 LOGGER.debug("%s position of '%s' : %s", side, label, labelWidths)
342 return widths, topEdge
345def _draw_label(widths1, widths2, minDistanceOfLabels=150000):
346 return (
347 np.abs(
348 (widths1["top"] - 0.5 * (widths1["top"] - widths1["bottom"]))
349 - (widths2["top"] - 0.5 * (widths2["top"] - widths2["bottom"]))
350 )
351 < minDistanceOfLabels
352 )