Coverage for MPT/sankey_gap.py: 72%
162 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-08 18:27 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-08 18:27 +0200
1# -*- coding: utf-8 -*-
2r"""
3Produces simple Sankey Diagrams with matplotlib.
4@author: Anneya Golob & marcomanz & pierre-sassoulas & jorwoods & vgalisson
5 .-.
6 .--.( ).--.
7 <-. .-.-.(.-> )_ .--.
8 `-`( )-' `) )
9 (o o ) `)`-'
10 ( ) ,)
11 ( () ) )
12 `---"\ , , ,/`
13 `--' `--' `--'
14 | | | |
15 | | | |
16 ' | ' |
17"""
19# fmt: off
20import warnings
21import logging
22from collections import defaultdict
24import matplotlib.pyplot as plt
25import matplotlib.patheffects as path_effects
26import numpy as np
27import pandas as pd
28import seaborn as sns
29# fmt: on
31LOGGER = logging.getLogger(__name__)
class PySankeyException(Exception):
    """Root of the exception hierarchy raised by this Sankey module.

    Catch this type to handle every module-specific error at once.
    """
class NullsInFrame(PySankeyException):
    """Signals that the left/right label data contains null values."""
class LabelMismatch(PySankeyException):
    """Signals that explicitly supplied labels differ from the labels in the data."""
def check_data_matches_labels(labels, data, side):
    """Check that the supplied labels are exactly the distinct values in data.

    Inputs:
        labels = sized iterable of expected labels (list, tuple, set,
            NumPy array, ...); an empty one disables the check
        data = pandas Series or any iterable holding the observed labels
        side = name of the diagram side ("left"/"right"), only used in the
            error message
    Raises:
        LabelMismatch : if the set of labels differs from the set of
            distinct values in data. Both sets are listed in the message
            when they hold at most 20 entries each.
    """
    if len(labels) == 0:
        return

    # Compare as sets: order and duplicates are irrelevant here. Converting
    # every kind of iterable (not only lists) avoids comparing e.g. a tuple
    # or an ndarray against a set, which can never compare equal cleanly.
    labels = set(labels)
    if isinstance(data, pd.Series):
        data = set(data.unique().tolist())
    else:
        data = set(data)

    if labels == data:
        return

    maxListed = 20  # do not flood the message with huge label sets
    msg = "\n"
    # str() + sorted(): labels may be non-strings, and set order is arbitrary.
    if len(labels) <= maxListed:
        msg += "Labels: " + ",".join(sorted(map(str, labels))) + "\n"
    if len(data) <= maxListed:
        msg += "Data: " + ",".join(sorted(map(str, data)))
    raise LabelMismatch(
        "{0} labels and data do not match.{1}".format(side, msg)
    )
def sankey(
    left,
    right,
    leftWeight=None,
    rightWeight=None,
    colorDict=None,
    leftLabels=None,
    rightLabels=None,
    aspect=4,
    rightColor=False,
    fontsize="medium",
    figureName=None,
    closePlot=False,
    figSize=None,
    ax=None,
):
    """
    Make Sankey Diagram showing flow from left-->right

    Inputs:
        left = NumPy array of object labels on the left of the diagram
        right = NumPy array of corresponding labels on the right of the diagram
            len(right) == len(left)
        leftWeight = NumPy array of weights for each strip starting from the
            left of the diagram, if not specified 1 is assigned
        rightWeight = NumPy array of weights for each strip starting from the
            right of the diagram, if not specified the corresponding leftWeight
            is assigned
        colorDict = Dictionary of colors to use for each label
            {'label':'color'}; it must hold every label used for coloring.
            If not specified, an "hls" seaborn palette is generated.
        leftLabels = order of the left labels in the diagram
        rightLabels = order of the right labels in the diagram
        aspect = vertical extent of the diagram in units of horizontal extent
        rightColor = If true, each strip in the diagram will be colored
            according to its right label (default: its left label)
        fontsize = font size forwarded to the label texts
        figureName = deprecated; if given, the figure is saved as
            "<figureName>.png"
        closePlot = deprecated; if true, the current figure is closed
        figSize = deprecated; tuple setting the width and height of the
            current figure
        ax = optional, matplotlib axes to plot on, otherwise uses current axes.
    Output:
        ax : matplotlib Axes
    Raises:
        NullsInFrame : if left or right contain null values
        LabelMismatch : if leftLabels/rightLabels are given but do not match
            the labels found in the data
    """
    # Report deprecated arguments, both as warnings and in one log line.
    deprecated = []
    for argName, inUse in (
        ("figureName", figureName is not None),
        ("closePlot", closePlot is not False),
        ("figSize", figSize is not None),
    ):
        if inUse:
            # stacklevel=2 attributes the warning to the caller of sankey().
            warnings.warn(
                f"use of {argName} in sankey() is deprecated",
                DeprecationWarning,
                stacklevel=2,
            )
            deprecated.append(f"{argName} in sankey()")
    if deprecated:
        LOGGER.warning(
            " The following arguments are deprecated and should be removed: %s",
            ", ".join(deprecated),
        )

    if ax is None:
        ax = plt.gca()

    # Default weights: 1 per strip on the left, mirrored on the right.
    if leftWeight is None or len(leftWeight) == 0:
        leftWeight = np.ones(len(left))
    if rightWeight is None or len(rightWeight) == 0:
        rightWeight = leftWeight
    if leftLabels is None:
        leftLabels = []
    if rightLabels is None:
        rightLabels = []

    # Create Dataframe. Every Series input (weights included) gets a fresh
    # positional index; otherwise pandas would align it on its own index
    # against range(len(left)) and silently produce NaN rows.
    left, right, leftWeight, rightWeight = (
        column.reset_index(drop=True) if isinstance(column, pd.Series) else column
        for column in (left, right, leftWeight, rightWeight)
    )
    dataFrame = pd.DataFrame(
        {
            "left": left,
            "right": right,
            "leftWeight": leftWeight,
            "rightWeight": rightWeight,
        },
        index=range(len(left)),
    )

    if (dataFrame.left.isnull() | dataFrame.right.isnull()).any():
        raise NullsInFrame("Sankey graph does not support null values.")

    # Identify all labels that appear 'left' or 'right'
    allLabels = pd.Series(
        np.r_[dataFrame.left.unique(), dataFrame.right.unique()]
    ).unique()
    LOGGER.debug("Labels to handle : %s", allLabels)

    # Identify left labels
    if len(leftLabels) == 0:
        leftLabels = dataFrame.left.unique()
    else:
        check_data_matches_labels(leftLabels, dataFrame["left"], "left")

    # Identify right labels
    if len(rightLabels) == 0:
        rightLabels = dataFrame.right.unique()
    else:
        check_data_matches_labels(rightLabels, dataFrame["right"], "right")

    # If no colorDict given, make one
    if colorDict is None:
        colorPalette = sns.color_palette("hls", len(allLabels))
        colorDict = dict(zip(allLabels, colorPalette))
    LOGGER.debug("The colordict value are : %s", colorDict)

    # Determine widths of individual strips. Each (left, right) pair is
    # filtered once; the pairs that actually occur are remembered so the
    # plotting loop does not have to filter the frame again.
    ns_l = {}
    ns_r = {}
    existingStrips = set()
    for leftLabel in leftLabels:
        fromLeft = dataFrame.left == leftLabel
        ns_l[leftLabel] = {}
        ns_r[leftLabel] = {}
        for rightLabel in rightLabels:
            strip = dataFrame[fromLeft & (dataFrame.right == rightLabel)]
            ns_l[leftLabel][rightLabel] = strip.leftWeight.sum()
            ns_r[leftLabel][rightLabel] = strip.rightWeight.sum()
            if len(strip) > 0:
                existingStrips.add((leftLabel, rightLabel))

    # Determine positions of left/right label patches and total widths
    leftWidths, _ = _get_positions_and_total_widths(dataFrame, leftLabels, "left")
    rightWidths, topEdge = _get_positions_and_total_widths(
        dataFrame, rightLabels, "right"
    )

    # Horizontal extent derived from the vertical one.
    # NOTE(review): only the right side's top edge feeds xMax; confirm this
    # is intended when both sides differ in height.
    xMax = topEdge / aspect

    # Print the labels on both sides of the diagram
    _place_side_labels(
        ax, leftLabels, leftWidths, "left", -0.05 * xMax, "right", fontsize
    )
    _place_side_labels(
        ax, rightLabels, rightWidths, "right", 1.05 * xMax, "left", fontsize
    )

    ymin, ymax = None, None
    # Plot strips
    for leftLabel in leftLabels:
        for rightLabel in rightLabels:
            if (leftLabel, rightLabel) not in existingStrips:
                continue
            labelColor = rightLabel if rightColor else leftLabel
            leftBottom = leftWidths[leftLabel]["bottom"]
            rightBottom = rightWidths[rightLabel]["bottom"]
            leftExtent = ns_l[leftLabel][rightLabel]
            rightExtent = ns_r[leftLabel][rightLabel]

            # Lower and upper outline of the strip
            ys_d = _smooth_strip_edge(leftBottom, rightBottom)
            ys_u = _smooth_strip_edge(
                leftBottom + leftExtent, rightBottom + rightExtent
            )

            # Update bottom edges at each label so next strip starts at the right place
            leftWidths[leftLabel]["bottom"] += leftExtent
            rightWidths[rightLabel]["bottom"] += rightExtent
            ax.fill_between(
                np.linspace(-0.013, 1.013, len(ys_d)) * xMax,
                ys_d,
                ys_u,
                alpha=1.0,
                zorder=1,
                facecolor=colorDict[labelColor],
            )

            # Track the overall vertical extent of everything drawn so far
            stripLow = min(ys_d.min(), ys_u.min())
            stripHigh = max(ys_d.max(), ys_u.max())
            ymin = stripLow if ymin is None else min(ymin, stripLow)
            ymax = stripHigh if ymax is None else max(ymax, stripHigh)

    ax.set_ylim(ymin, ymax)
    ax.axis("off")

    # Deprecated behaviors, still honored
    if figSize is not None:
        plt.gcf().set_size_inches(figSize)

    if figureName is not None:
        fileName = "{}.png".format(figureName)
        plt.savefig(fileName, bbox_inches="tight", dpi=150)
        LOGGER.info("Sankey diagram generated in '%s'", fileName)
    if closePlot:
        plt.close()

    return ax


def _place_side_labels(ax, labels, widths, side, xPos, hAlign, fontsize):
    """Write the labels of one diagram side, skipping labels that would crowd."""
    lastDrawn = ""
    for position, label in enumerate(labels):
        # The first label is always drawn; later ones only when far enough
        # from the last label that was actually drawn.
        if position != 0 and _draw_label(widths[label], widths[lastDrawn]):
            continue
        ax.text(
            xPos,
            widths[label]["bottom"] + 0.5 * widths[label][side],
            # LaTeX bold markup — presumably text rendering is configured
            # by the caller to interpret it; TODO confirm.
            rf"\textbf{{{label}}}",
            {"ha": hAlign, "va": "center"},
            fontsize=fontsize,
            zorder=2,
        )
        lastDrawn = label


def _smooth_strip_edge(leftValue, rightValue):
    """Return the y-values of a strip edge easing from leftValue to rightValue."""
    # Half the samples at the left value, half at the right one ...
    edge = np.array(50 * [leftValue] + 50 * [rightValue])
    # ... then two passes of a 20-sample moving average smooth the step.
    kernel = 0.05 * np.ones(20)
    for _ in range(2):
        edge = np.convolve(edge, kernel, mode="valid")
    return edge
def _get_positions_and_total_widths(df, labels, side, gap=50000):
    """Determine positions of label patches and total widths

    Inputs:
        df = DataFrame holding a `side` column with labels and a
            `side + "Weight"` column with weights
        labels = labels of that side, in stacking order (bottom first)
        side = "left" or "right"; selects the columns read from df
        gap = fixed vertical space inserted between consecutive patches,
            in weight units, on top of a padding of 2% of the total weight
    Output:
        widths : dict mapping each label to
            {side: summed weight, "bottom": lower edge, "top": upper edge}
        topEdge : upper edge of the last patch, 0 if labels is empty
    """
    weightColumn = side + "Weight"
    # Loop-invariant padding between patches: 2% of the side's total weight.
    padding = 0.02 * df[weightColumn].sum()

    widths = {}
    topEdge = 0  # stays 0 when there is nothing to stack
    for i, label in enumerate(labels):
        extent = df[df[side] == label][weightColumn].sum()
        labelWidths = {side: extent}
        if i == 0:
            labelWidths["bottom"] = 0
            labelWidths["top"] = extent
        else:
            # topEdge still holds the previous patch's upper edge here.
            labelWidths["bottom"] = topEdge + gap + padding
            labelWidths["top"] = labelWidths["bottom"] + extent
        topEdge = labelWidths["top"]
        widths[label] = labelWidths
        LOGGER.debug("%s position of '%s' : %s", side, label, labelWidths)

    return widths, topEdge
def _draw_label(widths1, widths2, minDistanceOfLabels=150000):
    """Tell whether two label patches lie too close to each other.

    Nothing is drawn here. The vertical midpoints of both patches (dicts
    with "top" and "bottom" entries) are compared, and the result is true
    when they are less than minDistanceOfLabels apart.
    """

    def midpoint(patch):
        # Midway between the lower and upper edge of the patch.
        return patch["top"] - 0.5 * (patch["top"] - patch["bottom"])

    separation = np.abs(midpoint(widths1) - midpoint(widths2))
    return separation < minDistanceOfLabels