mgplot.bar_plot
Create bar plots using Matplotlib.
Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.
"""Create bar plots using Matplotlib.

Note: bar plots in Matplotlib are not the same as bar charts in other
libraries. Bar plots are used to represent categorical data with
rectangular bars. As a result, bar plots and line plots typically
cannot be plotted on the same axes.
"""

from collections.abc import Sequence
from typing import Any, Final, NotRequired, TypedDict, Unpack

import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from pandas import DataFrame, Period, Series

from mgplot.axis_utils import map_periodindex, map_stringindex, set_labels
from mgplot.keyword_checking import BaseKwargs, report_kwargs, validate_kwargs
from mgplot.settings import DataT, get_setting
from mgplot.utilities import (
    apply_defaults,
    constrain_data,
    default_rounding,
    get_axes,
    get_color_list,
)

# --- constants
ME: Final[str] = "bar_plot"
MAX_ANNOTATIONS: Final[int] = 30  # series longer than this are never annotated
ADJUSTMENT_FACTOR: Final[float] = 0.02  # annotation nudge, as a share of the series range
MIN_BAR_WIDTH: Final[float] = 0.0
MAX_BAR_WIDTH: Final[float] = 1.0
DEFAULT_GROUPED_WIDTH: Final[float] = 0.8  # used when a grouped width is out of range
DEFAULT_BAR_OFFSET: Final[float] = 0.5
DEFAULT_MAX_TICKS: Final[int] = 10


class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function."""

    # --- options for the entire bar plot
    ax: NotRequired[Axes | None]
    stacked: NotRequired[bool]  # bar_plot defaults this to False (grouped bars)
    max_ticks: NotRequired[int]
    plot_from: NotRequired[int | Period]
    label_rotation: NotRequired[int | float]
    # --- options for each bar ...
    color: NotRequired[str | Sequence[str]]
    label_series: NotRequired[bool | Sequence[bool]]
    width: NotRequired[float | int | Sequence[float | int]]
    zorder: NotRequired[int | float | Sequence[int | float]]
    # --- options for bar annotations
    annotate: NotRequired[bool]
    fontsize: NotRequired[int | float | str]
    fontname: NotRequired[str]
    # bool | int to agree with AnnoKwargs and with bar_plot's default of True
    rounding: NotRequired[bool | int]
    rotation: NotRequired[int | float]
    annotate_color: NotRequired[str]
    above: NotRequired[bool]


# --- functions
class AnnoKwargs(TypedDict, total=False):
    """TypedDict for the kwargs used in annotate_bars."""

    annotate: bool
    fontsize: int | float | str
    fontname: str
    color: str
    rotation: int | float
    foreground: str  # used for stroke effect on text
    above: bool
    rounding: bool | int  # if True, uses default rounding; if int, uses that value


def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray,
    axes: Axes,
    **anno_kwargs: Unpack[AnnoKwargs],
) -> None:
    """Write the value of each bar as a text label on the plot.

    Nothing is drawn unless "annotate" is truthy, the series has no more
    than MAX_ANNOTATIONS rows, and base has one entry per row of series.

    Args:
        series: Series - the bar heights; its index (cast to int) gives
            the x-position of each bar.
        offset: float - horizontal shift added to each index value.
        base: np.ndarray - the bottom of each bar, in the same row order
            as series.
        axes: Axes - the axes to write the text on.
        **anno_kwargs: AnnoKwargs - "annotate", "fontsize", "fontname",
            "color", and "rotation" are expected; "above", "rounding" and
            "foreground" are optional.

    """
    # --- only annotate in limited circumstances
    if not anno_kwargs.get("annotate"):
        return
    if len(series) > MAX_ANNOTATIONS:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
    annotate_style: dict[str, Any] = {
        "fontsize": anno_kwargs.get("fontsize"),
        "fontname": anno_kwargs.get("fontname"),
        "color": anno_kwargs.get("color"),
        "rotation": anno_kwargs.get("rotation"),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
    # small nudge away from the bar edge, scaled to the spread of the data
    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR

    # --- annotate each bar
    # base is aligned with series by row position, so index it by position
    # rather than by (index value - smallest index value); the latter is only
    # correct for a sorted, gap-free integer index.
    x_positions = series.index.astype(int)
    for row, (index, value) in enumerate(zip(x_positions, series, strict=True)):
        position = base[row] + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])


class GroupedKwargs(TypedDict):
    """TypedDict for the kwargs used in grouped."""

    color: Sequence[str]
    width: Sequence[float | int]
    label_series: Sequence[bool]
    zorder: Sequence[int | float | None]


def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    Each column of df is drawn as its own set of bars, placed side by side
    within the unit-wide slot centred on each index value. Columns that are
    entirely NaN are skipped, but still count towards the slot division.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - one column per bar series; the index must support
            numeric addition.
        anno_args: AnnoKwargs - annotation settings passed to annotate_bars
            (not modified by this function).
        **kwargs: GroupedKwargs - per-column sequences, indexed by column
            position.

    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            zorder=kwargs["zorder"][i],
            # a leading underscore keeps the series out of the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # work on a copy so the caller's annotation settings are left untouched
        bar_anno = anno_args.copy()
        bar_anno["foreground"] = foreground
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **bar_anno,
        )


class StackedKwargs(TypedDict):
    """TypedDict for the kwargs used in stacked."""

    color: Sequence[str]
    width: Sequence[float | int]
    label_series: Sequence[bool]
    zorder: Sequence[int | float | None]


def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are stacked in column order. Non-negative values are stacked
    upwards from zero and negative values are stacked downwards from zero,
    using two separate running totals.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - one column per bar series.
        anno_args: AnnoKwargs - annotation settings passed to annotate_bars
            (not modified by this function).
        **kwargs: StackedKwargs - per-column sequences, indexed by column
            position.

    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts at the top of the positive stack or the bottom of the negative stack
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            zorder=kwargs["zorder"][i],
            # a leading underscore keeps the series out of the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # work on a copy so the caller's annotation settings are left untouched
        bar_anno = anno_args.copy()
        bar_anno["foreground"] = foreground
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **bar_anno,
        )
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)


def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns are drawn as grouped (side-by-side) bars. With
    stacked=True the columns are stacked on top of each other, with
    positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with string indices
    saved_strings = map_stringindex(df)
    if saved_strings is not None:
        df = saved_strings[0]

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, bool | int] = {
        "stacked": False,
        "max_ticks": DEFAULT_MAX_TICKS,
        "label_rotation": 0,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
        "zorder": None,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # text above the bar defaults to black; text inside the bar defaults to white
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, _unused_kwargs = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle index labels and rotation
    if saved_strings is not None:
        axes.set_xticks(range(len(saved_strings[1])))
        axes.set_xticklabels(saved_strings[1])
    elif saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    # rotate the labels of the axes we drew on, which need not be
    # pyplot's current axes when the caller supplied ax=...
    plt.setp(axes.get_xticklabels(), rotation=chart_args["label_rotation"])

    return axes
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function."""

    # --- options for the entire bar plot
    ax: NotRequired[Axes | None]
    stacked: NotRequired[bool]  # bar_plot defaults this to False (grouped bars)
    max_ticks: NotRequired[int]
    plot_from: NotRequired[int | Period]
    label_rotation: NotRequired[int | float]
    # --- options for each bar ...
    color: NotRequired[str | Sequence[str]]
    label_series: NotRequired[bool | Sequence[bool]]
    width: NotRequired[float | int | Sequence[float | int]]
    zorder: NotRequired[int | float | Sequence[int | float]]
    # --- options for bar annotations
    annotate: NotRequired[bool]
    fontsize: NotRequired[int | float | str]
    fontname: NotRequired[str]
    # bool | int to agree with AnnoKwargs and with bar_plot's default of True
    rounding: NotRequired[bool | int]
    rotation: NotRequired[int | float]
    annotate_color: NotRequired[str]
    above: NotRequired[bool]
Keyword arguments for the bar_plot function.
66class AnnoKwargs(TypedDict, total=False): 67 """TypedDict for the kwargs used in annotate_bars.""" 68 69 annotate: bool 70 fontsize: int | float | str 71 fontname: str 72 color: str 73 rotation: int | float 74 foreground: str # used for stroke effect on text 75 above: bool 76 rounding: bool | int # if True, uses default rounding; if int, uses that value
TypedDict for the kwargs used in annotate_bars.
def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray,
    axes: Axes,
    **anno_kwargs: Unpack[AnnoKwargs],
) -> None:
    """Write the value of each bar as a text label on the plot.

    Nothing is drawn unless "annotate" is truthy, the series has no more
    than MAX_ANNOTATIONS rows, and base has one entry per row of series.

    Args:
        series: Series - the bar heights; its index (cast to int) gives
            the x-position of each bar.
        offset: float - horizontal shift added to each index value.
        base: np.ndarray - the bottom of each bar, in the same row order
            as series.
        axes: Axes - the axes to write the text on.
        **anno_kwargs: AnnoKwargs - "annotate", "fontsize", "fontname",
            "color", and "rotation" are expected; "above", "rounding" and
            "foreground" are optional.

    """
    # --- only annotate in limited circumstances
    if not anno_kwargs.get("annotate"):
        return
    if len(series) > MAX_ANNOTATIONS:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
    annotate_style: dict[str, Any] = {
        "fontsize": anno_kwargs.get("fontsize"),
        "fontname": anno_kwargs.get("fontname"),
        "color": anno_kwargs.get("color"),
        "rotation": anno_kwargs.get("rotation"),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
    # small nudge away from the bar edge, scaled to the spread of the data
    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR

    # --- annotate each bar
    # base is aligned with series by row position, so index it by position
    # rather than by (index value - smallest index value); the latter is only
    # correct for a sorted, gap-free integer index.
    x_positions = series.index.astype(int)
    for row, (index, value) in enumerate(zip(x_positions, series, strict=True)):
        position = base[row] + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])
Bar plot annotations.
Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
133class GroupedKwargs(TypedDict): 134 """TypedDict for the kwargs used in grouped.""" 135 136 color: Sequence[str] 137 width: Sequence[float | int] 138 label_series: Sequence[bool] 139 zorder: Sequence[int | float | None]
TypedDict for the kwargs used in grouped.
def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    Each column of df is drawn as its own set of bars, placed side by side
    within the unit-wide slot centred on each index value. Columns that are
    entirely NaN are skipped, but still count towards the slot division.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - one column per bar series; the index must support
            numeric addition.
        anno_args: AnnoKwargs - annotation settings passed to annotate_bars
            (not modified by this function).
        **kwargs: GroupedKwargs - per-column sequences, indexed by column
            position.

    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            zorder=kwargs["zorder"][i],
            # a leading underscore keeps the series out of the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # work on a copy so the caller's annotation settings are left untouched
        bar_anno = anno_args.copy()
        bar_anno["foreground"] = foreground
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **bar_anno,
        )
Plot a grouped bar plot.
176class StackedKwargs(TypedDict): 177 """TypedDict for the kwargs used in stacked.""" 178 179 color: Sequence[str] 180 width: Sequence[float | int] 181 label_series: Sequence[bool] 182 zorder: Sequence[int | float | None]
TypedDict for the kwargs used in stacked.
def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are stacked in column order. Non-negative values are stacked
    upwards from zero and negative values are stacked downwards from zero,
    using two separate running totals.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - one column per bar series.
        anno_args: AnnoKwargs - annotation settings passed to annotate_bars
            (not modified by this function).
        **kwargs: StackedKwargs - per-column sequences, indexed by column
            position.

    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts at the top of the positive stack or the bottom of the negative stack
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            zorder=kwargs["zorder"][i],
            # a leading underscore keeps the series out of the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # work on a copy so the caller's annotation settings are left untouched
        bar_anno = anno_args.copy()
        bar_anno["foreground"] = foreground
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **bar_anno,
        )
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
Plot a stacked bar plot.
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns are drawn as grouped (side-by-side) bars. With
    stacked=True the columns are stacked on top of each other, with
    positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with string indices
    saved_strings = map_stringindex(df)
    if saved_strings is not None:
        df = saved_strings[0]

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, bool | int] = {
        "stacked": False,
        "max_ticks": DEFAULT_MAX_TICKS,
        "label_rotation": 0,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
        "zorder": None,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # text above the bar defaults to black; text inside the bar defaults to white
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, _unused_kwargs = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle index labels and rotation
    if saved_strings is not None:
        axes.set_xticks(range(len(saved_strings[1])))
        axes.set_xticklabels(saved_strings[1])
    elif saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    # rotate the labels of the axes we drew on, which need not be
    # pyplot's current axes when the caller supplied ax=...
    plt.setp(axes.get_xticklabels(), rotation=chart_args["label_rotation"])

    return axes
Create a bar plot from the given data.
By default, each column in the DataFrame is drawn as a set of grouped (side-by-side) bars. With stacked=True, the columns are instead stacked on top of each other, with positive values above zero and negative values below zero.
Args: data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series. **kwargs: BarKwargs - Additional keyword arguments for customization. (see BarKwargs for details)
Note: This function does not assume all data is timeseries with a PeriodIndex.
Returns: axes: Axes - The axes for the plot.