Coverage for C:\src\imod-python\imod\visualize\waterbalance.py: 10%
61 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-08 10:26 +0200
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-08 10:26 +0200
1import itertools
3import matplotlib.pyplot as plt
4import numpy as np
5import pandas as pd
def _draw_bars(ax, x, df, labels, barwidth, colors):
    """
    Draw the columns of ``df`` as stacked bars at the positions ``x``.

    Every row of ``df`` becomes one stacked bar; every column becomes one
    segment of that bar, stacked on top of the preceding columns.

    Parameters
    ----------
    ax : matplotlib.Axes
        Axes to draw on (only its ``bar`` method is used).
    x : array_like
        Bar positions, one per row of ``df``.
    df : pandas.DataFrame
        Values to stack, one column per segment.
    labels : listlike of str
        Legend label for each column of ``df``.
    barwidth : float
        Width of the bars.
    colors : listlike, optional
        One color per column. When None, no ``color`` is passed to
        ``ax.bar``, leaving the choice to matplotlib.
    """
    nrows, _ = df.shape
    # Each segment starts at the running total of the columns before it:
    # prepend a zero column and drop the final (grand) total.
    running_total = df.cumsum(axis=1).values
    starts = np.concatenate([np.zeros((nrows, 1)), running_total], axis=1)
    bottoms = starts.T[:-1]
    heights = df.values.T

    def draw_segment(label, bottom, height, **style):
        ax.bar(
            x,
            bottom=bottom,
            height=height,
            width=barwidth,
            edgecolor="k",
            label=label,
            **style,
        )

    segments = zip(labels, bottoms, heights)
    if colors is None:
        for label, bottom, height in segments:
            draw_segment(label, bottom, height)
        return
    for (label, bottom, height), color in zip(segments, colors):
        draw_segment(label, bottom, height, color=color)
def _split_colors(colors, inflows, outflows):
    """
    Split ``colors`` into the colors of the inflow bars and the outflow bars.

    A dict is looked up by flow name (a missing flow raises KeyError). Any
    other non-string listlike is taken in order: first the colors of
    ``inflows``, then those of ``outflows``. Returns ``(None, None)`` when
    ``colors`` is None or a plain string.
    """
    if colors is None or isinstance(colors, str):
        # A plain string has never been split into per-flow colors; keep
        # ignoring it rather than splitting it into single characters.
        return None, None
    if isinstance(colors, dict):
        return [colors[k] for k in inflows], [colors[k] for k in outflows]
    colors = list(colors)
    ninflows = len(inflows)
    return colors[:ninflows], colors[ninflows:]


def _format_date(date, fmt):
    """
    Format a single x-axis date label with ``strftime``.

    Values without a ``strftime`` method (e.g. an integer index) fall back
    to ``str(date)`` instead of raising AttributeError.
    """
    if hasattr(date, "strftime"):
        return date.strftime(fmt)
    return str(date)


def waterbalance_barchart(
    df,
    inflows,
    outflows,
    datecolumn=None,
    format="%Y-%m-%d",
    ax=None,
    unit=None,
    colors=None,
):
    """
    Draw a water balance as stacked bar charts.

    For every row (date) of ``df``, a stacked "in" bar of the ``inflows``
    columns is drawn next to a stacked "out" bar of the ``outflows`` columns.

    Parameters
    ----------
    df : pandas.DataFrame
        The dataframe containing the water balance data, one row per date.
    inflows : listlike of str
        Column names of ``df`` stacked in the "in" bar.
    outflows : listlike of str
        Column names of ``df`` stacked in the "out" bar.
    datecolumn : str, optional
        Column of ``df`` holding the dates used as x-axis labels. When None,
        the index of ``df`` is used.
    format : str, optional
        ``strftime`` format of the date labels. Default: "%Y-%m-%d". Values
        without a ``strftime`` method are labelled with ``str(value)``.
    ax : matplotlib.Axes, optional
        Axes to draw on. When None, the current axes (``plt.gca()``) is used.
    unit : str, optional
        Label of the y-axis.
    colors : dict or listlike of strings or tuples, optional
        Bar colors. A dict maps flow name to color; a listlike gives the
        colors of ``inflows`` followed by those of ``outflows``. When None,
        matplotlib's default colors are used.

    Returns
    -------
    ax : matplotlib.Axes

    Raises
    ------
    TypeError
        If ``df`` is not a pandas.DataFrame.
    ValueError
        If ``datecolumn`` or any flow is not a column of ``df``, or if fewer
        colors than flows are given.

    Examples
    --------

    >>> fig, ax = plt.subplots()
    >>> imod.visualize.waterbalance_barchart(
    ...     ax=ax,
    ...     df=df,
    ...     inflows=["Rainfall", "River upstream"],
    ...     outflows=["Evapotranspiration", "Discharge to Sea"],
    ...     datecolumn="Time",
    ...     format="%Y-%m-%d",
    ...     unit="m3/d",
    ...     colors=["#ca0020", "#f4a582", "#92c5de", "#0571b0"],
    ... )
    >>> fig.savefig("Waterbalance.png", dpi=300, bbox_inches="tight")

    """
    # Do some checks
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df should be a pandas.DataFrame")
    # Normalize to lists: df[...] only selects multiple columns with a list
    # (a tuple is taken as a single key), and "+" on e.g. a pandas.Index
    # concatenates the strings element-wise instead of joining the sequences.
    inflows = list(inflows)
    outflows = list(outflows)
    if datecolumn is not None and datecolumn not in df.columns:
        raise ValueError(f"datecolumn {datecolumn} not in df")
    for column in itertools.chain(inflows, outflows):
        if column not in df:
            raise ValueError(f"{column} not in df")
    if colors is not None:
        ncolors = len(colors)
        nflows = len(inflows) + len(outflows)
        if ncolors < nflows:
            raise ValueError(
                f"Not enough colors: Number of flows is {nflows}, while number of colors is {ncolors}"
            )
    # Deal with colors, takes both dict and listlike
    incolors, outcolors = _split_colors(colors, inflows, outflows)

    # Determine x positions: per date an "in" bar at r1 and an "out" bar at
    # r2 directly to its right, followed by one empty bar width.
    ndates = len(df)
    barwidth = 1.0
    r1 = np.arange(0.0, ndates * barwidth * 3, barwidth * 3)
    r2 = r1 + barwidth
    r_between = 0.5 * (r1 + r2)

    # Grab ax if not provided directly
    if ax is None:
        ax = plt.gca()

    # Draw inflows
    _draw_bars(
        ax=ax, x=r1, df=df[inflows], labels=inflows, barwidth=barwidth, colors=incolors
    )
    # Draw outflows
    _draw_bars(
        ax=ax,
        x=r2,
        df=df[outflows],
        labels=outflows,
        barwidth=barwidth,
        colors=outcolors,
    )

    # Place xticks: "in" under r1, the date between both bars, "out" under r2
    xticks_location = list(itertools.chain(*zip(r1, r_between, r2)))
    dates = df.index if datecolumn is None else df[datecolumn]
    xticks_labels = []
    for date in dates:
        # Place the date labels two lines (two \n) below the "in"/"out" labels
        xticks_labels.extend(["in", f"\n\n{_format_date(date, format)}", "out"])

    # The major ticks only carry the labels, so their tick marks are hidden.
    # The minor ticks mark the boundaries between dates and are lengthened so
    # they extend down past the labels (the length is in points; 45 per bar
    # width is presumably tuned to the two-line offset above).
    ax.tick_params(axis="x", which="major", bottom=False, top=False, labelbottom=True)
    ax.tick_params(
        axis="x",
        which="minor",
        bottom=True,
        top=False,
        labelbottom=False,
        length=barwidth * 45,
    )
    ax.xaxis.set_ticks(xticks_location)
    ax.xaxis.set_ticklabels(xticks_labels)
    # Separators sit in the empty bar width before every date except the first
    xticks_location_minor = r1[1:] - barwidth
    ax.xaxis.set_ticks(xticks_location_minor, minor=True)

    # Create a legend on the right side of the chart
    ax.legend(
        loc="upper left",
        bbox_to_anchor=(1.03, 1.0),
        ncol=2,
        borderaxespad=0,
        frameon=True,
    )

    # Set a unit on the y-axis. (Axis.set_label would only set the legend
    # label of the axis artist, leaving the y-axis unlabelled.)
    if unit is not None:
        ax.set_ylabel(unit)

    return ax